[fix] torch.nn.functional.embedding -> padding_idx behavior #46714

Closed
15 changes: 13 additions & 2 deletions aten/src/ATen/native/Embedding.cpp
@@ -17,16 +17,27 @@ Tensor embedding(const Tensor & weight, const Tensor & indices,
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarType("embedding", indices_arg, kLong);
 
+  auto zerofill_padding = [&](Tensor& embedding) {
+    if (padding_idx >= 0) {
+      embedding.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
+    }
+  };
+
   // TODO: use tensor.index() after improving perf
   if (indices.dim() == 1) {
-    return weight.index_select(0, indices);
+    auto out = weight.index_select(0, indices);
+    zerofill_padding(out);
+    return out;
   }
 
   auto size = indices.sizes().vec();
   for (auto d : weight.sizes().slice(1)) {
     size.push_back(d);
   }
-  return weight.index_select(0, indices.reshape(-1)).view(size);
+
+  auto out = weight.index_select(0, indices.reshape(-1));
+  zerofill_padding(out);
+  return out.view(size);
 }
 
 Tensor embedding_backward(
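
Aside, not part of the diff: a minimal Python sketch of the semantics the patched ATen embedding() implements. The helper name embedding_with_padding and the example tensors are illustrative only.

import torch

def embedding_with_padding(weight, indices, padding_idx=-1):
    # Mirrors the patched kernel: gather rows with index_select, then
    # zero-fill every row that was looked up at padding_idx.
    out = weight.index_select(0, indices.reshape(-1))
    if padding_idx >= 0:
        out = out.masked_fill((indices == padding_idx).reshape(-1, 1), 0)
    return out.view(*indices.shape, weight.shape[1])

weight = torch.randn(4, 3)
idx = torch.tensor([[0, 2, 1], [2, 3, 2]])
out = embedding_with_padding(weight, idx, padding_idx=2)
assert torch.all(out[idx == 2] == 0)  # padding positions come back as zero rows
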
16 changes: 16 additions & 0 deletions test/test_nn.py
@@ -3055,6 +3055,13 @@ def test_embedding_functional(self):
         res_F = F.embedding(a, embeddings)
         self.assertEqual(res_old, res_F)
 
+        embed_old = torch.nn.Embedding(4, 3)
+        embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
+        res_old = embed_old(a)
+        res_F = F.embedding(a, embeddings, padding_idx=2)
+
+        self.assertEqual(res_old, res_F)
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
                          ' with instruction set support avx2 or newer.')
@@ -10707,6 +10714,15 @@ def fn(weight):
         fn = fn_wrapper(device)
         _assertGradAndGradgradChecks(self, fn, (weight, ))
 
+        def fn_wrapper(device):
+            def padding_fn(weight):
+                inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long).to(device)
+                return torch.nn.functional.embedding(inp, weight, padding_idx=1)
+            return padding_fn
+
+        fn = fn_wrapper(device)
+        _assertGradAndGradgradChecks(self, fn, (weight, ))
+
     def test_embedding_scalar_weight_error(self, device):
         indices = torch.rand(2, 2, device=device).long()
         weight = torch.tensor(1.0, device=device)
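
Aside, not part of the diff: the new device test above boils down to a grad/gradgrad check of F.embedding with padding_idx set. A rough standalone equivalent using torch.autograd.gradcheck / gradgradcheck (the double-precision weight and the index values are illustrative):

import torch
from torch.autograd import gradcheck, gradgradcheck

weight = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
inp = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]], dtype=torch.long)

def padding_fn(w):
    return torch.nn.functional.embedding(inp, w, padding_idx=1)

assert gradcheck(padding_fn, (weight,))       # exercises embedding_dense_backward
assert gradgradcheck(padding_fn, (weight,))   # exercises embedding_dense_double_backward
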
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -1182,7 +1182,7 @@
   weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)
 
 - name: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
-  grad_output: embedding_dense_double_backward(grad, indices)
+  grad_output: embedding_dense_double_backward(grad, indices, padding_idx)
   indices: non_differentiable
 
 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
10 changes: 6 additions & 4 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -2668,16 +2668,18 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
   return at::constant_pad_nd(grad, negated_pad, 0);
 }
 
-Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices) {
-  // since first backward takes care of padding_idx
-  // and scaling by frequency, we don't need to worry
-  // about it here.
+Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
+  // since first backward takes care of scaling by frequency,
+  // we don't need to worry about it here.
   auto gg_weight = grad.index_select(0, indices.reshape(-1));
 
   // reshape gradient as per the shape of indices
   auto size = indices.sizes().vec();
   size.push_back(-1);
 
+  if (padding_idx >= 0) {
+    gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0);
+  }
   return gg_weight.view(size);
 }
 
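
Aside, not part of the diff: a Python transcription of the patched embedding_dense_double_backward for readers who prefer not to parse the ATen code; the function name and example tensors are illustrative only.

import torch

def embedding_dense_double_backward(grad, indices, padding_idx=-1):
    # The double backward w.r.t. grad_output is a gather of `grad` at `indices`.
    # Rows taken at padding_idx are zeroed so the result stays consistent with
    # the first backward, which accumulates no gradient for padding_idx.
    gg_weight = grad.index_select(0, indices.reshape(-1))
    if padding_idx >= 0:
        gg_weight = gg_weight.masked_fill((indices == padding_idx).reshape(-1, 1), 0)
    return gg_weight.view(*indices.shape, -1)

grad = torch.randn(4, 3)                              # gradient w.r.t. grad_weight
indices = torch.tensor([[0, 1, 1, 2], [1, 1, 0, 2]])
print(embedding_dense_double_backward(grad, indices, padding_idx=1).shape)  # torch.Size([2, 4, 3])
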
2 changes: 1 addition & 1 deletion torch/csrc/autograd/FunctionsManual.h
@@ -117,7 +117,7 @@ at::Tensor logdet_backward(const at::Tensor & grad, const at::Tensor& self, cons
 at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor& self, const at::Tensor& signdet, const at::Tensor& logabsdet);
 at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
 at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape);
-at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices);
+at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
 at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
 at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
