Skip to content

Commit

Permalink
[MPS] Fix topk for empty tensors and k=0
Browse files Browse the repository at this point in the history
  • Loading branch information
soof-golan committed Jan 9, 2023
1 parent 73e5379 commit 4c35451
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 24 deletions.
32 changes: 32 additions & 0 deletions aten/src/ATen/native/mps/operations/Shape.mm
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
namespace at {
namespace native {

// Produces a shape with the `dim` dimension set to 0.
//
// Used by topk_out_mps to compute the output shape for k == 0, which
// MPSGraph cannot handle natively. `dim_` may be negative (Python-style
// indexing) and is wrapped via maybe_wrap_dim.
std::vector<int64_t> getTopK0Shape(IntArrayRef sizes, const int64_t dim_) {
  const int sz = sizes.size();
  // Scalar (0-dim) input: topk produces a 1-D result, which collapses to
  // a single zero-length dimension when k == 0.
  if (sz == 0) {
    return {0};
  }
  const int64_t dim = maybe_wrap_dim(dim_, sz);
  // Copy the input shape wholesale, then zero out only the reduced
  // dimension — avoids the per-element conditional copy loop.
  std::vector<int64_t> shape(sizes.begin(), sizes.end());
  shape[dim] = 0;
  return shape;
}

// topk
TORCH_IMPL_FUNC(topk_out_mps)
(const Tensor& self,
Expand All @@ -33,6 +49,22 @@
indices.zero_();
return;
}
// Handle empty tensors
if (self.numel() == 0)
{
values.copy_(self);
indices.copy_(values.toType(at::ScalarType::Long));
return;
}
// Handle k == 0 case. Needed because MPSGraph does not support k == 0.
if (k == 0)
{
const auto out_shape = getTopK0Shape(self.sizes(), dim);
values.resize_(out_shape);
indices.copy_(values.toType(at::ScalarType::Long));
return;
}

MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph
{
Expand Down
64 changes: 40 additions & 24 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4092,30 +4092,6 @@ def test_assert_topk(self):
ys_mps = ys_cpu.to('mps')
self.assertEqual(ys_cpu.topk(16), ys_mps.topk(16))

# NOTE(review): this is the pre-refactor topk test (deleted in this commit,
# superseded by the TestTopK class below). Indentation was lost in the page
# scrape; code bytes are preserved as-is.
def test_topk(self):
# Compare MPS topk against the CPU reference for every dim and every
# valid k from 1 up to the dim's size, for both largest=True/False.
def helper(shape):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
x = cpu_x.detach().clone().to('mps')
for largest_val in [True, False]:
if (type(shape) == tuple):
for curr_dim in range(0, len(shape)):
dim_size = shape[curr_dim]
# k starts at 1, so k == 0 is never exercised here.
for k in range(1, dim_size + 1):
topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest_val)
topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest_val)
self.assertEqual(topk_values, topk_values_cpu)
self.assertEqual(topk_indices, topk_indices_cpu)
else:
# Scalar shape: a 1-D tensor of length `shape`, reduced along dim 0.
# NOTE(review): upper bound excludes k == shape, unlike the tuple branch.
for k in range(1, shape):
topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest_val)
topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest_val)
self.assertEqual(topk_values, topk_values_cpu)
self.assertEqual(topk_indices, topk_indices_cpu)

helper(2)
helper((5, 1))
helper((1, 5))
helper((5, 9, 7, 4))

def test_upsample_nearest2d(self):
def helper(N, C, H, W):
Expand Down Expand Up @@ -5829,6 +5805,46 @@ def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)


class TestTopK(TestCase):
    """Compare MPS torch.topk against the CPU reference implementation.

    Covers empty tensors, k == 0 (the cases fixed by this commit), and a
    range of 1-D and N-D shapes, for both largest=True and largest=False.
    """

    def _test_topk(self, shape, largest):
        """Check MPS topk matches CPU topk for every dim and every valid k.

        shape: an int (1-D tensor of that length, reduced along dim 0) or a
        tuple of dim sizes (every dim is exercised).
        """
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')
        if isinstance(shape, tuple):
            for curr_dim, dim_size in enumerate(shape):
                # Start at k == 0 so the MPSGraph k == 0 workaround is
                # actually exercised (MPSGraph does not support k == 0).
                for k in range(0, dim_size + 1):
                    topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                    self.assertEqual(topk_values, topk_values_cpu)
                    self.assertEqual(topk_indices, topk_indices_cpu)
        else:
            # Include both bounds: k == 0 (MPSGraph workaround) and
            # k == shape (full selection), matching the tuple branch.
            for k in range(0, shape + 1):
                topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
                self.assertEqual(topk_indices, topk_indices_cpu)

    def test_topk(self):
        largest_vals = [True, False]
        shapes = [
            # Zero Element Tensors
            0,
            (1, 0),
            (0, 1),
            (1, 0, 1),
            # Multiple Element Tensors
            1,
            2,
            (5, 1),
            (1, 5),
            (5, 9, 7, 4),
        ]

        for shape in shapes:
            for largest_val in largest_vals:
                # subTest keeps each (shape, largest) combination reported
                # independently on failure.
                with self.subTest(shape=shape, largest_val=largest_val):
                    self._test_topk(shape, largest_val)

class TestNNMPS(NNTestCase):

def _create_basic_net(self):
Expand Down

0 comments on commit 4c35451

Please sign in to comment.