[MPS] Fix bug when value is of complex (pytorch#111937)
When the value passed to `fill` is complex, the check `value.toDouble() == 0.0` errors out, complaining that converting a complex value to double would cause an overflow. So we should handle complex values first and only reach this condition afterwards.
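
For illustration, a minimal repro of the failure described above (a hypothetical snippet, not part of the commit; it assumes an MPS-enabled build with complex64 support on the MPS backend):

import torch

# Before this fix, fill_ on a complex MPS tensor hit the zero-fill fast path,
# which called value.toDouble() on a complex Scalar and raised the overflow error.
x = torch.zeros(2, 3, device="mps", dtype=torch.complex64)
x.fill_(0.2 + 0.5j)  # previously raised; after this change it fills the tensor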

Pull Request resolved: pytorch#111937
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#111885
qqaatw authored and xuhancn committed Nov 8, 2023
1 parent 21ff9e6 commit 2987df3
Showing 2 changed files with 10 additions and 11 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/native/mps/operations/ConstantOps.mm
@@ -86,10 +86,6 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) {
 }
 
 Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
-  // check if it's possible to use fillBuffer() to fill the Tensor's storage
-  if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
-    return self;
-
   if (isComplexType(self.scalar_type())) {
     auto self_as_real = at::view_as_real(self);
     auto self_as_real_real = self_as_real.select(self.dim(), 0);
@@ -104,6 +100,10 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) {
     fill_scalar_mps_impl(self_as_real_imag, 0.0f);
     return self;
   }
+  // check if it's possible to use fillBuffer() to fill the Tensor's storage
+  if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
+    return self;
+
   return fill_scalar_mps_impl(self, value);
 }
 
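A rough Python analogue of the complex branch above (a sketch only, not taken from the commit): the complex tensor is viewed as real/imaginary planes via view_as_real and each plane is filled separately, so the fillBuffer() fast path is now only consulted for non-complex fills.

import torch

# Python sketch of the decomposition that the Objective-C++ code performs with
# at::view_as_real() and select(); runs on CPU, no MPS device required.
z = torch.zeros(2, 3, dtype=torch.complex64)
as_real = torch.view_as_real(z)       # shape (2, 3, 2): last dim holds (real, imag)
as_real.select(-1, 0).fill_(0.2)      # fill the real plane
as_real.select(-1, 1).fill_(0.5)      # fill the imaginary plane
# z is now filled with 0.2+0.5j, since view_as_real returns a view of z.
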
13 changes: 6 additions & 7 deletions test/test_mps.py
@@ -1407,19 +1407,18 @@ def testNumbersGPU(self):
 
     def test_fill(self):
 
-        def helper(val, shape):
-            tensor = torch.zeros(shape, device='mps')
+        def helper(val, shape, dtype):
+            tensor = torch.zeros(shape, device='mps', dtype=dtype)
             tensor_mps = tensor.fill_(val)
-            tensor_mps = torch.tanh(tensor_mps)
 
-            tensor_0 = torch.zeros(shape, device='cpu')
+            tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype)
             tensor_cpu = tensor_0.fill_(val)
-            tensor_cpu = torch.tanh(tensor_cpu)
 
             self.assertEqual(tensor_mps, tensor_cpu)
 
-        helper(0, [1024])
-        helper(0.2, [2, 3])
+        helper(0, [1024], torch.float32)
+        helper(0.2, [2, 3], torch.float32)
+        helper(0.2 + 0.5j, [2, 3], torch.complex64)
 
     def test_fill_storage_offset(self):
         shape = [2, 10]
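
For a quick manual check of the new complex case outside the test harness, a sketch along these lines should work (assumes torch.backends.mps.is_available() returns True; not part of the commit):

import torch

val, shape, dtype = 0.2 + 0.5j, [2, 3], torch.complex64
t_mps = torch.zeros(shape, device="mps", dtype=dtype).fill_(val)
t_cpu = torch.zeros(shape, device="cpu", dtype=dtype).fill_(val)
# Both backends should agree on the filled complex values.
assert torch.equal(t_mps.cpu(), t_cpu)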
