[quant][pyper] Support aten::embedding_bag quantization in graph mode #43989

Closed · 4 commits
33 changes: 20 additions & 13 deletions test/quantization/test_quantize_jit.py
@@ -2901,12 +2901,14 @@ def __init__(self, weights):
         self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10,
                                                 embedding_dim=12,
                                                 include_last_offset=True,
+                                                sparse=False,
                                                 _weight=weights,
                                                 mode='sum')
 
         self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10,
                                                 embedding_dim=12,
                                                 include_last_offset=True,
+                                                sparse=False,
                                                 _weight=weights,
                                                 mode='sum')
@@ -2917,21 +2919,26 @@ def forward(self, indices1, offsets1, indices2, offsets2):
         weights = torch.randn(10, 12, dtype=torch.float32)
         module = M(weights)
-        m = torch.jit.script(module)
 
         indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
         offsets = torch.tensor([0, 19, 20, 28, 28, 32])
 
-        from torch.quantization import QConfigDynamic, PlaceholderObserver
-        int4_dynamic_qconfig = QConfigDynamic(activation=PlaceholderObserver.with_args(dtype=torch.float,
-                                                                                       custom_op_name="embedding_bag_4bit"),
-                                              weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_4bit"))
-        int8_dynamic_qconfig = QConfigDynamic(activation=PlaceholderObserver.with_args(dtype=torch.float,
-                                                                                       custom_op_name="embedding_bag_byte"),
-                                              weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"))
-        m = quantize_dynamic_jit(m, {'embedding1' : int4_dynamic_qconfig, 'embedding2' : int8_dynamic_qconfig})
-        FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
-                   .check_next("quantized::embedding_bag_byte_rowwise_offsets") \
-                   .run(m.graph)
+        dummy_inputs = (indices, offsets, indices, offsets)
+        for trace in [True, False]:
+            if trace:
+                m = torch.jit.trace(module, dummy_inputs)
+            else:
+                m = torch.jit.script(module)
+            from torch.quantization import QConfigDynamic, PlaceholderObserver
+            int4_dynamic_qconfig = QConfigDynamic(activation=PlaceholderObserver.with_args(dtype=torch.float,
+                                                                                           custom_op_name="embedding_bag_4bit"),
+                                                  weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_4bit"))
Contributor commented:
Why do we have a placeholder observer for weights? My understanding is that we can use real observers for 8-bit but not for 4-bit currently. Is that correct?
Contributor (author) commented:
We currently use real observers and torchbind classes for eager-mode 8-bit embedding quantization. For graph mode we implemented this initially using the custom prepack ops for PyPer, for both 8-bit and 4-bit, to be consistent with C2.
Going forward, in fx we can implement embedding_bag quantization using observers. I feel it is a bit of overkill to update this code to use observers for 8-bit and placeholder observers for 4-bit. Let me know your thoughts.

+            int8_dynamic_qconfig = QConfigDynamic(activation=PlaceholderObserver.with_args(dtype=torch.float,
+                                                                                           custom_op_name="embedding_bag_byte"),
+                                                  weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"))
+            m = quantize_dynamic_jit(m, {'embedding1' : int4_dynamic_qconfig, 'embedding2' : int8_dynamic_qconfig})
+            FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
+                       .check_next("quantized::embedding_bag_byte_rowwise_offsets") \
+                       .run(m.graph)
 
 class TestQuantizeJit(QuantizationTestCase):
     @override_qengines
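Not part of the diff: a minimal sketch of how the quantized module this test produces can be exercised end to end. It assumes M, weights, indices, offsets, and the two qconfig objects from the hunk above are in scope, and that quantize_dynamic_jit comes from torch.quantization.quantize_jit as in the rest of this test file.

# Sketch only: quantize the scripted module and run it on the test inputs.
from torch.quantization.quantize_jit import quantize_dynamic_jit

m = torch.jit.script(M(weights))
m = quantize_dynamic_jit(m, {'embedding1': int4_dynamic_qconfig,
                             'embedding2': int8_dynamic_qconfig})
out = m(indices, offsets, indices, offsets)  # forward now runs the quantized ops
print(m.graph)  # expect quantized::embedding_bag_4bit_rowwise_offsets followed by
                # quantized::embedding_bag_byte_rowwise_offsets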
21 changes: 18 additions & 3 deletions torch/csrc/jit/passes/quantization/helper.cpp
@@ -51,6 +51,7 @@ std::vector<std::string> _dynamic_quantizable_call_funcs = {
 
 std::vector<std::string> _dynamic_quantizable_aten_funcs = {
     "linear",
+    "embedding_bag",
 };
 
 // These are the prim::CallFunctions that doesn't require observation and
@@ -259,10 +260,16 @@ bool matchArgPattern(
 bool isWeight(Value* v) {
   bool result = matchArgPattern(
       v,
-      AtenFuncArgs(
-          {{"conv1d", 1}, {"conv2d", 1}, {"conv3d", 1}, {"linear", 1}}),
+      // ate::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
Contributor commented:
nit: aten?
+      // %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
+      AtenFuncArgs({{"conv1d", 1},
+                    {"conv2d", 1},
+                    {"conv3d", 1},
+                    {"linear", 1},
+                    {"embedding_bag", 0}}),
       // embedding_bag - prim::CallFunction(%func, %input.1, %weight,
-      // %offsets.1, %7, %8, %9, %10, %9, %per_sample_weights.1, %13)
+      // %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
+      // %per_sample_weights.1, %include_last_offset)
       CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
   return result;
 }
@@ -276,6 +283,14 @@ bool isBiasOfConvOrLinear(Value* v) {
   return result;
 }
 
+bool isEmbeddingBagNonInput(Value* v) {
+  bool result = matchArgPattern(
+      v,
+      AtenFuncArgs({{"embedding_bag", 2}, {"embedding_bag", 6}}),
+      CallFuncArgs({}));
+  return result;
+}
+
 c10::optional<Use> getClampScalarInputUse(Value* v) {
   for (const auto& use : v->uses()) {
     for (const auto& aten_func : _clamp_funcs) {
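The two argument patterns above reflect the two IR forms an embedding-bag call can take: scripting goes through F.embedding_bag and shows up as a prim::CallFunction whose weight is input 2 (input 0 being the function itself), while tracing inlines to aten::embedding_bag with the weight at input 0. A rough sketch for inspecting both forms (the names here are illustrative, not from the PR):

# Sketch only: print the scripted and traced IR for an embedding_bag call.
import torch
import torch.nn.functional as F

def fn(indices, weight, offsets):
    return F.embedding_bag(indices, weight, offsets, mode='sum')

# Un-inlined scripted graph: prim::CallFunction(%embedding_bag, %indices, %weight, %offsets, ...)
print(torch.jit.script(fn).graph)

eb = torch.nn.EmbeddingBag(10, 12, mode='sum')
indices = torch.tensor([1, 2, 4, 5, 4, 3])
offsets = torch.tensor([0, 3])
# Traced graph: aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, ...)
print(torch.jit.trace(eb, (indices, offsets)).graph)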
9 changes: 9 additions & 0 deletions torch/csrc/jit/passes/quantization/helper.h
@@ -29,6 +29,8 @@ TORCH_API bool isWeight(Value* v);
 // quantize
 TORCH_API bool isBiasOfConvOrLinear(Value* v);
 
+TORCH_API bool isEmbeddingBagNonInput(Value* v);
+
 // Get the use as scalar input of clamp ops for the input value
 c10::optional<Use> getClampScalarInputUse(Value* v);
 
@@ -112,6 +114,13 @@ bool matchCallFuncToUse(
     const std::string& func_name,
     c10::optional<int> nth_arg);
 
+// Check if `use` is an AtenFunction of name `func_name` and if value
+// `v` is the nth argument (if provided) of the function
+bool matchAtenFuncToUse(
+    const Use& use,
+    const std::string& func_name,
+    c10::optional<int> nth_arg);
+
 // =========== helper functions for Block =========
 // checks if a block will always raise an Exception
 TORCH_API bool alwaysRaisesException(Block* block);
3 changes: 2 additions & 1 deletion torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -1162,7 +1162,8 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(
     const QConfig& qconfig) {
   if (isBiasOfConvOrLinear(v) ||
       !(v->type()->isSubtypeOf(TensorType::get()) ||
-        v->type()->isSubtypeOf(ListType::ofTensors()))) {
+        v->type()->isSubtypeOf(ListType::ofTensors())) ||
+      isEmbeddingBagNonInput(v)) {
     return false;
   }
   // For dynamic quantization we only insert observers at the input
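With isEmbeddingBagNonInput wired into this check, the prepare step should attach an observer only to the embedding weight, never to offsets or per_sample_weights. A quick, hedged way to eyeball that, reusing the test module and qconfigs from the first file (prepare_dynamic_jit lives in torch.quantization.quantize_jit):

# Sketch only: after prepare, observers should wrap the weights and nothing else.
from torch.quantization.quantize_jit import prepare_dynamic_jit

m = torch.jit.script(M(weights))
m = prepare_dynamic_jit(m, {'embedding1': int4_dynamic_qconfig,
                            'embedding2': int8_dynamic_qconfig})
print(m.graph)  # observer calls should appear on the weight values only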
60 changes: 43 additions & 17 deletions torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -308,44 +308,70 @@ Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
   // We expect that the output of the weight observer will be consumed by the
   // embedding_bag operator.
   for (const Use& use : uses) {
-    if (matchCallFuncToUse(use, "embedding_bag", 2)) {
+    if (matchCallFuncToUse(use, "embedding_bag", 2) ||
+        matchAtenFuncToUse(use, "embedding_bag", 0)) {
       embedding_bag_float_op = use.user;
     }
   }
-  TORCH_CHECK(
-      embedding_bag_float_op->inputs().size() == 11,
-      "Expecting FP EmbeddingBag operator to have 11 inputs");
 
   // Insert prepack op
   Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs);
   g->insertNode(prepack);
 
   std::vector<Value*> embedding_bag_inputs =
       embedding_bag_float_op->inputs().vec();
 
+  std::vector<Value*> qembedding_bag_inputs = {prepack->output()};
+  const auto inputs_size = embedding_bag_float_op->inputs().size();
+  const bool is_aten_op =
+      embedding_bag_float_op->kind() == Symbol::aten("embedding_bag");
   // Create and insert quantized embedding op.
   Value* none = g->insertConstant(IValue());
   Value* zero = g->insertConstant(IValue(0));
 
-  std::vector<Value*> qembedding_bag_inputs = {
-      /* weight */ prepack->output(),
-      /* indices */ embedding_bag_inputs[1],
-      /* offsets */ embedding_bag_inputs[3],
-      /* scale_grad_by_freq */ embedding_bag_inputs[6],
-      /* mode */ zero,
-      /* sparse */ embedding_bag_inputs[8],
-      /* per_sample_weights_ */ embedding_bag_inputs[9]};
+  if (is_aten_op) {
+    TORCH_CHECK(
+        inputs_size == 8,
+        "Expecting FP aten::embedding_bag operator to have 8 inputs");
+    // input 0 is the output of prepack op.
+    // Last input is added after we account for extra input in 4-bit case.
+    for (auto i = 1; i < inputs_size - 1; ++i) {
+      qembedding_bag_inputs.push_back(embedding_bag_inputs[i]);
+    }
+  } else {
+    TORCH_CHECK(
+        inputs_size == 11,
+        "Expecting F.embedding_bag operator to have 11 inputs");
+    qembedding_bag_inputs.push_back(embedding_bag_inputs[1]); // indices
+    qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets
+    qembedding_bag_inputs.push_back(
+        embedding_bag_inputs[6]); // scale_grad_by_freq
+    qembedding_bag_inputs.push_back(zero); // mode
+    qembedding_bag_inputs.push_back(embedding_bag_inputs[8]); // sparse
+    qembedding_bag_inputs.push_back(
+        embedding_bag_inputs[9]); // per_sample_weights
+  }
 
   if (op_name == "embedding_bag_4bit") {
     // 4-bit op has an extra input compressed_indices_mapping
     qembedding_bag_inputs.push_back(none);
   }
-  qembedding_bag_inputs.push_back(embedding_bag_inputs[10]);
+  qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 1]);
 
   Node* qembedding_bag =
       g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs);
-  g->insertNode(qembedding_bag);
 
-  embedding_bag_float_op->output()->replaceAllUsesWith(
+  if (is_aten_op) {
+    WithInsertPoint ins(embedding_bag_float_op);
+    g->insertNode(qembedding_bag);
+    // Verify that the outputs (apart from index 0) have no uses in the graph.
+    for (auto i = 1; i < embedding_bag_float_op->outputs().size(); ++i) {
+      TORCH_CHECK(
+          !embedding_bag_float_op->output(i)->hasUses(),
+          "Expected aten::embedding_bag to only have use for its first output.");
+    }
+  } else {
+    g->insertNode(qembedding_bag);
+  }
+  embedding_bag_float_op->output(0)->replaceAllUsesWith(
       qembedding_bag->output());
   embedding_bag_float_op->removeAllInputs();
   embedding_bag_float_op->destroy();
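After this pass runs, the float embedding_bag has been replaced by a prepack op feeding the quantized lookup. A rough eager-mode sketch of the 8-bit op pair the pass emits; the argument list mirrors the inputs assembled above, but treat the exact op signatures as an assumption rather than a guarantee:

# Sketch only: the op pair the rewritten graph executes in the byte case.
import torch

weight = torch.randn(10, 12, dtype=torch.float32)
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
offsets = torch.tensor([0, 3, 6, 9])  # last entry closes the final bag (include_last_offset)

# Row-wise 8-bit quantization; per-row scale/zero_point are packed alongside the data.
packed = torch.ops.quantized.embedding_bag_byte_prepack(weight)

# mode=0 is the constant the pass inserts for 'sum'.
out = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
    packed, indices, offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=True)
print(out.shape)  # one pooled row per bag: (3, 12)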