19 changes: 10 additions & 9 deletions aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
@@ -123,7 +123,6 @@ __global__ void compute_grad_weight(
int64_t* segment_offsets,
int64_t num_of_segments,
acc_type<scalar_t, true> *grad_weight_per_segment,
int padding_idx,
const int64_t stride_warped) {

using accscalar_t = acc_type<scalar_t, true>;
@@ -142,10 +141,8 @@ __global__ void compute_grad_weight(
accscalar_t weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
const int64_t target_row = indices[idx];
if (target_row != padding_idx) {
const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
weight += gradOutput[target_row * stride + startFeature] * scale;
}
const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
weight += gradOutput[target_row * stride + startFeature] * scale;
}
grad_weight_per_segment[id * stride + startFeature] = weight;
}
@@ -157,6 +154,7 @@ __global__ void sum_and_scatter(
int64_t* segment_offsets, int64_t num_of_segments,
const acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const int64_t padding_idx,
const int64_t stride_warped) {

const int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -175,8 +173,10 @@ __global__ void sum_and_scatter(
for (int idx=idx_begin; idx < idx_end; ++idx) {
weight += grad_weight_per_segment[idx*stride + startFeature];
}
const int weightRow = input[segment_offsets[id]] * stride;
gradWeight[weightRow + startFeature] = weight;
int64_t target_row = input[segment_offsets[id]];
if (target_row != padding_idx) {
gradWeight[target_row * stride + startFeature] = weight;
}
}

} // anon namespace
@@ -299,7 +299,6 @@ Tensor embedding_backward_cuda_kernel(
partial_segment_offset.data_ptr<int64_t>(),
num_of_partial_segments,
grad_weight_per_segment.data_ptr<partial_weight_t>(),
padding_idx,
stride_warped);
}
THCudaCheck(cudaGetLastError());
@@ -314,7 +313,9 @@ Tensor embedding_backward_cuda_kernel(
segment_offsets.data_ptr<int64_t>(),
num_of_segments, grad_weight_per_segment.data_ptr<partial_weight_t>(),
partials_per_segment_offset.data_ptr<int64_t>(),
num_of_partial_segments, stride_warped);
num_of_partial_segments,
padding_idx,
stride_warped);
THCudaCheck(cudaGetLastError());
});
return grad_weight;
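
To make the relocated padding check easier to follow, here is a minimal, self-contained Python/NumPy sketch of the restructured backward pass (an illustration only, not the kernel code; the function name and the simplified segmentation are assumptions, and scale_grad_by_freq handling is omitted). Per-segment gradients are now accumulated unconditionally, and the padding row is skipped only when the summed result is scattered into grad_weight:

import numpy as np

def embedding_backward_sketch(grad_output, indices, num_weights, padding_idx):
    """Dense embedding backward, mirroring the two-stage kernel structure."""
    # Sort indices so equal values form contiguous segments, as the CUDA path
    # does before launching compute_grad_weight / sum_and_scatter.
    order = np.argsort(indices, kind="stable")
    sorted_idx = indices[order]
    sorted_grad = grad_output[order]

    grad_weight = np.zeros((num_weights, grad_output.shape[1]))
    start = 0
    while start < len(sorted_idx):
        end = start
        while end < len(sorted_idx) and sorted_idx[end] == sorted_idx[start]:
            end += 1
        # Stage 1 analogue (compute_grad_weight): accumulate the segment
        # unconditionally -- the padding check no longer lives here.
        segment_sum = sorted_grad[start:end].sum(axis=0)
        # Stage 2 analogue (sum_and_scatter): skip the write-back when the
        # destination row is padding_idx.
        if sorted_idx[start] != padding_idx:
            grad_weight[sorted_idx[start]] = segment_sum
        start = end
    return grad_weight

# Example: row 0 is the padding row and receives no gradient.
g = embedding_backward_sketch(np.ones((4, 2)), np.array([0, 1, 0, 3]), 5, 0)
print(g[0])  # [0. 0.]
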
104 changes: 52 additions & 52 deletions test/test_nn.py
@@ -2269,58 +2269,6 @@ def test_move_sparse_half_embedding(self):
embedding.to('cpu')
self.assertEqual(embedding.weight.device.type, 'cpu')

def test_embedding_padding_idx(self):
embedding = nn.Embedding(10, 20, padding_idx=0)
input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
output = embedding(input)
self.assertEqual(output[0][0].sum(), 0)
self.assertEqual(output[1][2].sum(), 0)

embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True)
input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long)
output = embedding(input)
self.assertEqual(output[0][0].sum(), 0)
self.assertEqual(output[1][2].sum(), 0)

# negative indexing check for padding_idx
# padding_idx=-2, num_embeddings=10 ==> index 8 padded
embedding = nn.Embedding(10, 20, padding_idx=-2)
input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long)
output = embedding(input)
self.assertEqual(output[0][2].sum(), 0)
self.assertEqual(output[1][1].sum(), 0)

embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True)
input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long)
output = embedding(input)
self.assertEqual(output[0][2].sum(), 0)
self.assertEqual(output[1][1].sum(), 0)

# out of bounds check for padding_idx
self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=25)
self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=-25)

# test backward when input contains padding_idx
padding_idx = 0
embedding = nn.Embedding(5, 2, padding_idx=padding_idx)
for n in (1, 2):
for other_indices in ([], [1, 3], [2]):
indices = torch.tensor(other_indices + [padding_idx] * n, dtype=torch.long)
pre = embedding.weight[padding_idx].clone()
embedding(indices).sum().backward()
after = (embedding.weight + embedding.weight.grad)[padding_idx]
embedding.zero_grad()
self.assertEqual(after, pre)

# test double backward
emb_sum = embedding(indices).sum()
emb_grad = torch.autograd.grad(outputs=emb_sum, inputs=list(embedding.parameters()), retain_graph=True)
scalar = emb_grad[0].sum() + emb_sum
scalar.backward()
after = (embedding.weight + embedding.weight.grad)[padding_idx]
embedding.zero_grad()
self.assertEqual(after, pre)

def test_embedding_max_norm(self):
embedding = nn.Embedding(22, 5, max_norm=1.0)
input = torch.tensor([2, 8, 8, 6], dtype=torch.long)
@@ -8808,6 +8756,58 @@ def test_embedding_backward(self, device, dtype):
self.assertEqual(embedding.weight.grad._indices(), tensorTwice)
self.assertEqual(embedding.weight.grad._values(), onesTwice)

def test_embedding_padding_idx(self, device):
embedding = nn.Embedding(10, 20, padding_idx=0).to(device)
input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device)
output = embedding(input)
self.assertEqual(output[0][0].sum(), 0)
self.assertEqual(output[1][2].sum(), 0)

embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True).to(device)
input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long).to(device)
output = embedding(input)
self.assertEqual(output[0][0].sum(), 0)
self.assertEqual(output[1][2].sum(), 0)

# negative indexing check for padding_idx
# padding_idx=-2, num_embeddings=10 ==> index 8 padded
embedding = nn.Embedding(10, 20, padding_idx=-2).to(device)
input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device)
output = embedding(input)
self.assertEqual(output[0][2].sum(), 0)
self.assertEqual(output[1][1].sum(), 0)

embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True).to(device)
input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long).to(device)
output = embedding(input)
self.assertEqual(output[0][2].sum(), 0)
self.assertEqual(output[1][1].sum(), 0)

# out of bounds check for padding_idx
self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=25)
self.assertRaises(AssertionError, nn.Embedding, num_embeddings=10, embedding_dim=20, padding_idx=-25)

# test backward when input contains padding_idx
padding_idx = 0
embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to(device)
for n in (1, 2, 1000): # Need large N to trigger all the methods we have implemented
for other_indices in ([], [1, 3], [2]):
indices = torch.tensor(other_indices + [padding_idx] * n, dtype=torch.long).to(device)
pre = embedding.weight[padding_idx].clone()
embedding(indices).sum().backward()
after = (embedding.weight + embedding.weight.grad)[padding_idx]
embedding.zero_grad()
self.assertEqual(after, pre)

# test double backward
emb_sum = embedding(indices).sum()
emb_grad = torch.autograd.grad(outputs=emb_sum, inputs=list(embedding.parameters()), retain_graph=True)
scalar = emb_grad[0].sum() + emb_sum
scalar.backward()
after = (embedding.weight + embedding.weight.grad)[padding_idx]
embedding.zero_grad()
self.assertEqual(after, pre)

@dtypesIfCUDA(torch.half, torch.float)
@dtypes(torch.float)
def test_softmax_backward(self, device, dtype):
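
As a quick usage note, the user-visible effect covered by the relocated test can be checked directly in an interactive session (a small sketch assuming a CUDA device is available; the tensor values are illustrative):

import torch
import torch.nn as nn

padding_idx = 0
embedding = nn.Embedding(5, 2, padding_idx=padding_idx).to('cuda')
indices = torch.tensor([1, 3, padding_idx, padding_idx],
                       dtype=torch.long, device='cuda')

embedding(indices).sum().backward()
# The padding row accumulates no gradient, so it is never updated.
assert torch.all(embedding.weight.grad[padding_idx] == 0)
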