Skip to content

Commit

Permalink
Fix incorrect CUDA torch.nn.Embedding result when max_norm is not None and indices are not sorted (#45248)
Browse files Browse the repository at this point in the history

Summary:
Sorting indices before calling `thrust::unique` fixes the issue.
Fixes #44792

Pull Request resolved: #45248

Reviewed By: mruberry

Differential Revision: D24194696

Pulled By: ngimel

fbshipit-source-id: ab59ef9d46b9917b1417bab25f80ce9780f0c930
  • Loading branch information
kurtamohler authored and facebook-github-bot committed Oct 13, 2020
1 parent 88dcb95 commit 66505b6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
5 changes: 1 addition & 4 deletions aten/src/ATen/native/cuda/Embedding.cu
Expand Up @@ -344,12 +344,9 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
using device_ptr = thrust::device_ptr<int64_t>;

auto num_indices = indices.numel();
- auto indices_contig = indices.contiguous();
+ auto indices_contig = std::get<0>(indices.sort()).contiguous();
auto indices_data = device_ptr(indices_contig.data_ptr<int64_t>());

- // FIXME: thrust::unique only removes consecutive elements that are equal.
- // We have race conditions when indices contain duplicates which are not
- // adjacent
auto unique_indices = at::empty(indices.numel(), indices.options());
auto unique_data = device_ptr(unique_indices.data_ptr<int64_t>());
auto end = thrust::unique_copy(policy, indices_data, indices_data + num_indices, unique_data);
Expand Down
17 changes: 17 additions & 0 deletions test/test_nn.py
Expand Up @@ -2954,6 +2954,23 @@ def test_threshold_int(self):
expected = torch.tensor([99, 99, 99, 99, 1, 2, 3])
self.assertEqual(F.threshold(x, 0, 99), expected)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_embedding_max_norm_unsorted_repeating_indices(self):
    """Compare CUDA and CPU Embedding outputs when max_norm is set and the
    indices contain duplicates that are not adjacent to each other."""
    def build_module(target):
        # Re-seed before every construction so both modules are created
        # from the same RNG state.
        torch.manual_seed(0)
        module = torch.nn.Embedding(
            num_embeddings=20,
            embedding_dim=64,
            max_norm=1.0)
        return module.to(target)

    # Index pattern 0, 1, 0, 1, ...: repeated values, never consecutive.
    cpu_indices = torch.arange(2, device='cpu', dtype=torch.long).repeat(2000)
    reference = build_module('cpu')(cpu_indices)

    cuda_indices = cpu_indices.to('cuda')
    result = build_module('cuda')(cuda_indices)
    self.assertEqual(result.cpu(), reference)

def test_embedding_sparse_basic(self):
embedding = nn.Embedding(10, 20, sparse=True)
input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
Expand Down

0 comments on commit 66505b6

Please sign in to comment.