Add padding_idx argument to EmbeddingBag (#49237)
Summary:
This PR adds a `padding_idx` parameter to `nn.EmbeddingBag` and `nn.functional.embedding_bag`. As with `nn.Embedding`'s `padding_idx` argument, any index equal to `padding_idx` is ignored, so it does not contribute to the reduction.
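As a quick illustration of those semantics (a minimal sketch with invented values, not taken from the PR's test suite):

```python
import torch
import torch.nn as nn

# Bag 0 covers indices[0:2], bag 1 covers indices[2:5].
indices = torch.tensor([1, 0, 0, 3, 4])
offsets = torch.tensor([0, 2])

bag = nn.EmbeddingBag(num_embeddings=5, embedding_dim=3, mode="sum", padding_idx=0)

out = bag(indices, offsets)
# Index 0 equals padding_idx, so bag 0 reduces to weight[1] alone and
# bag 1 to weight[3] + weight[4].
assert torch.allclose(out[0], bag.weight[1])
assert torch.allclose(out[1], bag.weight[3] + bag.weight[4])
```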

This PR does not add `padding_idx` support for quantized `EmbeddingBag` or for ONNX export at opset 10/11 (opset 9 is supported); in those cases, an error is thrown if `padding_idx` is provided.

Fixes #3194

Pull Request resolved: #49237

Reviewed By: walterddr, VitalyFedyunin

Differential Revision: D26948258

Pulled By: jbschlosser

fbshipit-source-id: 3ca672f7e768941f3261ab405fc7597c97ce3dfc
kurtamohler authored and facebook-github-bot committed Apr 14, 2021
1 parent f94c95a commit 3fe4718
Showing 26 changed files with 927 additions and 293 deletions.
425 changes: 254 additions & 171 deletions aten/src/ATen/native/EmbeddingBag.cpp

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions aten/src/ATen/native/EmbeddingBag.h
```diff
@@ -35,13 +35,15 @@ void make_offset2bag_out(
     const Tensor& indices,
     const Tensor& offsets,
     const int64_t mode,
-    const c10::optional<Tensor>& per_sample_weights);
+    const c10::optional<Tensor>& per_sample_weights,
+    const int64_t padding_idx = -1);
 
 void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
     Tensor& bag_size, Tensor& max_indices,
     const Tensor &weight, const Tensor &indices,
     const Tensor &offsets, const int64_t mode = 0,
    const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
-    bool include_last_offset = false);
+    bool include_last_offset = false,
+    int64_t padding_idx = -1);
 } // native
 } // at
```
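In the C++ signatures above, `padding_idx = -1` is the internal "no padding" sentinel, while the Python-facing APIs take an optional argument defaulting to `None`. A sanity sketch of the assumption that omitting the argument and passing `None` explicitly are equivalent:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 3])

# Omitting padding_idx and passing None should both map to the
# internal -1 sentinel and produce identical results.
out_default = F.embedding_bag(indices, weight, offsets, mode="sum")
out_none = F.embedding_bag(indices, weight, offsets, mode="sum", padding_idx=None)
assert torch.equal(out_default, out_none)
```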
1 change: 0 additions & 1 deletion aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
```diff
@@ -190,7 +190,6 @@ Tensor embedding_backward_cuda_kernel(
     const Tensor &count,
     int64_t num_weights,
     int padding_idx,
-    bool scale_grad_by_freq,
     bool mode_mean,
     const Tensor &offset2bag,
     const Tensor &bag_size,
```
1 change: 0 additions & 1 deletion aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh
```diff
@@ -27,7 +27,6 @@ Tensor embedding_backward_cuda_kernel(
     const Tensor &count,
     int64_t num_weights,
     int padding_idx = -1,
-    bool scale_grad_by_freq = false,
     bool mode_mean = false,
     const Tensor &offset2bag = Tensor(),
     const Tensor &bag_size = Tensor(),
```
185 changes: 116 additions & 69 deletions aten/src/ATen/native/cuda/EmbeddingBag.cu

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions aten/src/ATen/native/native_functions.yaml
```diff
@@ -1458,7 +1458,7 @@
 # The backward functions apply a check that these input tensors are contiguous.
 
 
-- func: _embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
   dispatch:
     CPU: _embedding_bag_forward_only_cpu
     CUDA: _embedding_bag_forward_only_cuda
@@ -1472,21 +1472,27 @@
 - func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
 
-- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
+# To keep backward and forward compatibility, and to avoid ambiguity with the
+# original signature above, scale_grad_by_freq, mode, sparse,
+# per_sample_weights, and include_last_offset parameters do not have default
+# values. Once the original signature is removed, default values can be added.
+- func: embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
+
+- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
   dispatch:
     CPU: _embedding_bag_cpu
     CUDA: _embedding_bag_cuda
 
-- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights) -> Tensor
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
 
-- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights) -> Tensor
+- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
 
-- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights) -> Tensor
+- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
   dispatch:
     CPU: _embedding_bag_dense_backward_cpu
     CUDA: _embedding_bag_dense_backward_cuda
 
-- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode) -> Tensor
+- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor
   dispatch:
     CPU: _embedding_bag_per_sample_weights_backward_cpu
     CUDA: _embedding_bag_per_sample_weights_backward_cuda
```
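The schema changes above thread `padding_idx` through the forward op and every backward variant. For `mean` mode, "not included in the reduction" means padded entries are excluded from both the sum and the divisor. The following reference sketch (illustrative Python only, not the kernel's actual implementation) makes that explicit:

```python
import torch
import torch.nn.functional as F

def embedding_bag_mean_ref(weight, indices, offsets, padding_idx):
    # Append an end bound so each bag is a contiguous slice of `indices`.
    bounds = offsets.tolist() + [len(indices)]
    out = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        bag = indices[start:end]
        kept = bag[bag != padding_idx]            # drop padded entries
        if kept.numel() == 0:
            out.append(torch.zeros(weight.size(1)))
        else:
            out.append(weight[kept].mean(dim=0))  # divisor counts kept entries only
    return torch.stack(out)

weight = torch.randn(10, 3)
indices = torch.tensor([2, 0, 5, 0, 0, 7])
offsets = torch.tensor([0, 3])

ref = embedding_bag_mean_ref(weight, indices, offsets, padding_idx=0)
out = F.embedding_bag(indices, weight, offsets, mode="mean", padding_idx=0)
assert torch.allclose(out, ref)
```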
```diff
@@ -88,6 +88,7 @@
     ("aten::cumprod_backward", datetime.date(2021, 5, 1)),
     ("aten::_triangular_solve_helper", datetime.date(9999, 1, 1)),
     ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
+    ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
 ]
 
 def allow_listed(schema, allow_list):
```
4 changes: 2 additions & 2 deletions test/cpp/api/functional.cpp
```diff
@@ -1283,8 +1283,8 @@ TEST_F(FunctionalTest, EmbeddingBag) {
   auto offsets = torch::tensor({0,4}, torch::kLong);
   auto weight = torch::empty({10, 3});
   torch::nn::init::normal_(weight);
-  auto y = F::embedding_bag(input, weight, F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets));
-  auto y_exp = std::get<0>(torch::embedding_bag(weight, input, offsets, false, 0, false, torch::Tensor()));
+  auto y = F::embedding_bag(input, weight, F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets).padding_idx(4));
+  auto y_exp = std::get<0>(torch::embedding_bag(weight, input, offsets, false, 0, false, torch::Tensor(), false, 4));
   ASSERT_TRUE(torch::allclose(y, y_exp));
 
   // no options test
```
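A rough Python counterpart of the updated C++ test above, for readers following along in Python. The input values here are chosen for the example (the test's own input line is elided in this excerpt), and the positional form of the raw `torch.embedding_bag` overload is assumed to mirror the C++ call:

```python
import torch
import torch.nn.functional as F

input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
offsets = torch.tensor([0, 4], dtype=torch.long)
weight = torch.empty(10, 3)
torch.nn.init.normal_(weight)

y = F.embedding_bag(input, weight, offsets, mode="sum", padding_idx=4)
# Positional arguments mirror the C++ call: (weight, indices, offsets,
# scale_grad_by_freq, mode, sparse, per_sample_weights,
# include_last_offset, padding_idx).
y_exp = torch.embedding_bag(weight, input, offsets, False, 0, False, None, False, 4)[0]
assert torch.allclose(y, y_exp)
```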
3 changes: 3 additions & 0 deletions test/cpp/api/modules.cpp
```diff
@@ -4568,6 +4568,9 @@ TEST_F(ModulesTest, PrettyPrintEmbeddingBag) {
   ASSERT_EQ(
       c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum))),
       "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum)");
+  ASSERT_EQ(
+      c10::str(EmbeddingBag(EmbeddingBagOptions(10, 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum).padding_idx(5))),
+      "torch::nn::EmbeddingBag(num_embeddings=10, embedding_dim=2, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true, mode=kSum, padding_idx=5)");
 }
 
 TEST_F(ModulesTest, PrettyPrintL1Loss) {
```
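On the Python side, the module repr is expected to gain the same field (a hedged sketch; the exact formatting of the Python repr may differ from the C++ pretty-printer tested above):

```python
import torch.nn as nn

bag = nn.EmbeddingBag(10, 2, max_norm=2.0, norm_type=2.5,
                      scale_grad_by_freq=True, sparse=True,
                      mode="sum", padding_idx=5)
print(bag)
# Expected to include the new field, along the lines of:
# EmbeddingBag(10, 2, max_norm=2.0, norm_type=2.5, scale_grad_by_freq=True,
#              mode='sum', sparse=True, padding_idx=5)
```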
55 changes: 55 additions & 0 deletions test/quantization/test_quantize_jit.py
```diff
@@ -3433,6 +3433,61 @@ def forward(self, indices1, offsets1, indices2, offsets2):
         ).run(m.graph)
         m(*dummy_inputs)
 
+    # Ensure that attempting to quantize an EmbeddingBag throws an error if
+    # padding_idx is not None
+    @skipIfNoFBGEMM
+    def test_embedding_bag_padding_idx_error(self):
+        class M(torch.nn.Module):
+            def __init__(self, weights):
+                super(M, self).__init__()
+                self.embedding = torch.nn.EmbeddingBag(
+                    num_embeddings=10,
+                    embedding_dim=12,
+                    include_last_offset=True,
+                    sparse=True,
+                    _weight=weights,
+                    mode="sum",
+                    padding_idx=0,
+                )
+
+            def forward(self, indices, offsets):
+                e = self.embedding(indices, offsets)
+                return e
+
+        weights = torch.randn(10, 12, dtype=torch.float32)
+        module = M(weights)
+
+        indices = torch.tensor([0, 1, 2, 3, 4])
+        offsets = torch.tensor([0, 2, 5])
+        dummy_inputs = (indices, offsets)
+
+        int4_qconfig = QConfig(
+            activation=PlaceholderObserver.with_args(
+                dtype=torch.float, custom_op_name="embedding_bag_4bit"
+            ),
+            weight=PlaceholderObserver.with_args(
+                custom_op_name="embedding_bag_4bit"
+            ),
+        )
+        int8_qconfig = QConfig(
+            activation=PlaceholderObserver.with_args(
+                dtype=torch.float, custom_op_name="embedding_bag_byte"
+            ),
+            weight=PlaceholderObserver.with_args(
+                custom_op_name="embedding_bag_byte"
+            ),
+        )
+
+        error_msg = r'Expected aten::embedding_bag padding_idx input to be None'
+        for trace, qconfig in itertools.product([True, False], [int4_qconfig, int8_qconfig]):
+            if trace:
+                m = torch.jit.trace(module, dummy_inputs)
+            else:
+                m = torch.jit.script(module)
+            m = prepare_jit(m, {"embedding": qconfig})
+            with self.assertRaisesRegex(RuntimeError, error_msg):
+                m = convert_jit(m)
+
+
 class TestQuantizeJit(QuantizationTestCase):
     @override_qengines
```
