From 66505b64a5f26094b9b037c9777fec2706b2e5ad Mon Sep 17 00:00:00 2001
From: Kurt Mohler
Date: Mon, 12 Oct 2020 18:26:07 -0700
Subject: [PATCH] Fix incorrect CUDA `torch.nn.Embedding` result when max_norm
 is not None and indices are not sorted (#45248)

Summary:
Sorting indices before calling `thrust::unique` fixes the issue.

Fixes https://github.com/pytorch/pytorch/issues/44792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45248

Reviewed By: mruberry

Differential Revision: D24194696

Pulled By: ngimel

fbshipit-source-id: ab59ef9d46b9917b1417bab25f80ce9780f0c930
---
 aten/src/ATen/native/cuda/Embedding.cu |  5 +----
 test/test_nn.py                        | 17 +++++++++++++++++
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu
index 365db61d06c0..dbf968084e6e 100644
--- a/aten/src/ATen/native/cuda/Embedding.cu
+++ b/aten/src/ATen/native/cuda/Embedding.cu
@@ -344,12 +344,9 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
     using device_ptr = thrust::device_ptr<index_t>;
 
     auto num_indices = indices.numel();
-    auto indices_contig = indices.contiguous();
+    auto indices_contig = std::get<0>(indices.sort()).contiguous();
     auto indices_data = device_ptr(indices_contig.data_ptr<index_t>());
 
-    // FIXME: thrust::unique only removes consecutive elements that are equal.
-    // We have race conditions when indices contain duplicates which are not
-    // adjacent
     auto unique_indices = at::empty(indices.numel(), indices.options());
     auto unique_data = device_ptr(unique_indices.data_ptr<index_t>());
     auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
diff --git a/test/test_nn.py b/test/test_nn.py
index ccf4ea7aa8d1..2f1a1d9e49a0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2954,6 +2954,23 @@ def test_threshold_int(self):
         expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
         self.assertEqual(F.threshold(x, 0, 99), expected)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_embedding_max_norm_unsorted_repeating_indices(self):
+        def create_embedding(device):
+            # Seed RNG so we get the same Embedding each time
+            torch.manual_seed(0)
+            return torch.nn.Embedding(
+                num_embeddings=20,
+                embedding_dim=64,
+                max_norm=1.0).to(device)
+
+        ix = torch.arange(2, device='cpu', dtype=torch.long).repeat(2000)
+        out_cpu = create_embedding('cpu')(ix)
+
+        ix = ix.to('cuda')
+        out = create_embedding('cuda')(ix)
+        self.assertEqual(out.cpu(), out_cpu)
+
     def test_embedding_sparse_basic(self):
         embedding = nn.Embedding(10, 20, sparse=True)
         input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
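
Why the one-line fix works: thrust::unique_copy only collapses runs of
adjacent equal elements, so unsorted indices with non-adjacent duplicates
(e.g. [0, 1, 0, 1]) all survive "deduplication", and the renorm kernel then
rewrites the same embedding row from multiple blocks concurrently. Sorting
first makes every duplicate adjacent. Below is a minimal standalone Thrust
sketch of the difference; it is illustrative only, not part of the patch,
and the toy index values are assumed for demonstration.

#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <cstdio>

int main() {
  // Indices with non-adjacent duplicates, like the repeating pattern the
  // regression test builds with torch.arange(2).repeat(2000).
  int h[] = {0, 1, 0, 1, 0, 1};
  thrust::device_vector<int> indices(h, h + 6);
  thrust::device_vector<int> out(indices.size());

  // No pair of neighbors is equal, so unique_copy removes nothing: the
  // "unique" set still has 6 entries, and each duplicate would later race.
  auto end = thrust::unique_copy(indices.begin(), indices.end(), out.begin());
  printf("unsorted: %d entries\n", (int)(end - out.begin()));  // prints 6

  // After sorting, duplicates are adjacent and unique_copy yields {0, 1}.
  thrust::sort(indices.begin(), indices.end());
  end = thrust::unique_copy(indices.begin(), indices.end(), out.begin());
  printf("sorted:   %d entries\n", (int)(end - out.begin()));  // prints 2
  return 0;
}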