diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index 365db61d06c0..dbf968084e6e 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -344,12 +344,9 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   using device_ptr = thrust::device_ptr<int64_t>;
 
   auto num_indices = indices.numel();
-  auto indices_contig = indices.contiguous();
+  auto indices_contig = std::get<0>(indices.sort()).contiguous();
   auto indices_data = device_ptr(indices_contig.data_ptr<int64_t>());
 
-  // FIXME: thrust::unique only removes consecutive elements that are equal.
-  // We have race conditions when indices contain duplicates which are not
-  // adjacent
   auto unique_indices = at::empty(indices.numel(), indices.options());
   auto unique_data = device_ptr(unique_indices.data_ptr<int64_t>());
   auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
diff --git a/test/test_nn.py b/test/test_nn.py
index ccf4ea7aa8d1..2f1a1d9e49a0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2954,6 +2954,23 @@ def test_threshold_int(self):
         expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
         self.assertEqual(F.threshold(x, 0, 99), expected)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_embedding_max_norm_unsorted_repeating_indices(self):
+        def create_embedding(device):
+            # Seed RNG so we get the same Embedding each time
+            torch.manual_seed(0)
+            return torch.nn.Embedding(
+                num_embeddings=20,
+                embedding_dim=64,
+                max_norm=1.0).to(device)
+
+        ix = torch.arange(2, device='cpu', dtype=torch.long).repeat(2000)
+        out_cpu = create_embedding('cpu')(ix)
+
+        ix = ix.to('cuda')
+        out = create_embedding('cuda')(ix)
+        self.assertEqual(out.cpu(), out_cpu)
+
     def test_embedding_sparse_basic(self):
         embedding = nn.Embedding(10, 20, sparse=True)
         input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
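
For context on the fix above: `thrust::unique_copy` only collapses runs of *consecutive* equal elements, so when the index list contains non-adjacent duplicates the "unique" buffer still holds repeats, and the renorm kernel (which handles one "unique" index per block) rescales the same embedding row from multiple blocks concurrently — the race the removed FIXME describes. Sorting the indices first makes every duplicate adjacent, restoring the precondition `unique_copy` needs. A minimal standalone Thrust sketch of that behavior (illustrative only; not part of this patch):

```cpp
// Sketch, assuming a standalone Thrust program: shows that
// thrust::unique_copy drops only *adjacent* duplicates, which is
// why the patch sorts the indices before deduplicating them.
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Repeating, unsorted indices like the new test produces: 0, 1, 0, 1, ...
  std::vector<int64_t> host = {0, 1, 0, 1, 0, 1};
  thrust::device_vector<int64_t> idx(host.begin(), host.end());
  thrust::device_vector<int64_t> out(idx.size());

  // No sort: no two equal elements are adjacent, so nothing is removed
  // and each index survives three times in the "unique" buffer.
  auto end = thrust::unique_copy(idx.begin(), idx.end(), out.begin());
  std::printf("without sort: %ld values survive\n", (long)(end - out.begin()));  // 6

  // With the sort the patch adds, duplicates become adjacent and
  // collapse to the true unique set {0, 1}.
  thrust::sort(idx.begin(), idx.end());
  end = thrust::unique_copy(idx.begin(), idx.end(), out.begin());
  std::printf("with sort:    %ld values survive\n", (long)(end - out.begin()));  // 2
  return 0;
}
```

The new test exercises exactly this pattern: `torch.arange(2).repeat(2000)` yields `0, 1, 0, 1, ...`, so every duplicate is non-adjacent, and the CUDA output is checked against the CPU reference, which was unaffected by the race.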