Skip to content

Commit

Permalink
[fix] masked_scatter_: non-contiguous self (#100232)
Browse files Browse the repository at this point in the history
Fixes #99638

Pull Request resolved: #100232
Approved by: https://github.com/ngimel
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Apr 28, 2023
1 parent 9cd48b0 commit 61dffa6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,8 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
// order of indexing matters
.enforce_linear_iteration()
.add_output(self)
.add_input(*b_mask)
.build();
Expand Down
34 changes: 34 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5414,6 +5414,40 @@ def test_item(self, device, dtype):
t = torch.ones((), device=device, dtype=dtype)
self.assertEqual(1, t.item())

@onlyNativeDeviceTypes
def test_masked_scatter_inplace_noncontiguous(self, device):
t = torch.zeros(5, 2, dtype=torch.long, device=device)
t_non_contig = t.transpose(0, 1)
t_contig = t_non_contig.contiguous()

assert t_contig.is_contiguous()
assert not t_non_contig.is_contiguous()

mask = torch.tensor([[False, True], [False, True], [False, False], [True, True], [True, True]], device=device)
mask_non_contig = mask.transpose(0, 1)
mask_contig = mask_non_contig.contiguous()

assert mask_contig.is_contiguous()
assert not mask_non_contig.is_contiguous()

# source is always converted to contiguous by the op.
source = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 9]], device=device)

# t: contig, mask: contig
expected = t_contig.masked_scatter_(mask_contig, source)

# t: non-contig, mask: non-contig
actual = t_non_contig.masked_scatter_(mask_non_contig, source)
self.assertEqual(actual, expected)

# t: contig, mask: non-contig
actual = t_contig.masked_scatter_(mask_non_contig, source)
self.assertEqual(actual, expected)

# t: non-contig, mask: contig
actual = t_non_contig.masked_scatter_(mask_contig, source)
self.assertEqual(actual, expected)


# Tests that compare a device's computation with the (gold-standard) CPU's.
class TestDevicePrecision(TestCase):
Expand Down

0 comments on commit 61dffa6

Please sign in to comment.