From 53af3b7a991319fc5101d58c97a4de59eaec14ab Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 19 Oct 2022 07:35:05 -0700 Subject: [PATCH] [MPS] Do not dispatch empty job in `bitwise_not` Follows the pattern from https://github.com/pytorch/pytorch/pull/85285 and returns before computing dispatching an empty metal kernel for bitwise not operation. Fixes crash when invoked with empty MPS tensor on AMD GPU --- aten/src/ATen/native/mps/operations/BitwiseOps.mm | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/mps/operations/BitwiseOps.mm b/aten/src/ATen/native/mps/operations/BitwiseOps.mm index 411f35e49b3c7..5b57693296b12 100644 --- a/aten/src/ATen/native/mps/operations/BitwiseOps.mm +++ b/aten/src/ATen/native/mps/operations/BitwiseOps.mm @@ -302,6 +302,10 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot } return output_; } + uint32_t length = output.numel(); + if (length == 0) { + return output_; + } using namespace at::mps; MPSStream* stream = getCurrentMPSStream(); id cplState = getCPLState(MPSDevice::getInstance()->device(), @@ -309,7 +313,6 @@ void handle_tensor_scalar_binary_op(const at::Tensor& self, const at::Scalar& ot getMetalType(self), getMetalType(self), "bitwise_not"); - uint32_t length = output.numel(); dispatch_sync(stream->queue(), ^(){ id buffer = stream->commandBuffer(); id commandEncoder = [buffer computeCommandEncoder];