[quant][graphmode][fx][fix] scalar as first input for add/mul (#46751)
Summary:
Pull Request resolved: #46751

Currently we assume the first input to add/mul is a Node (Tensor), but that is not always the case: the scalar can come first, e.g. 3 + x. This adds Scalar2 variants of the quantized add/mul ops that take the scalar as the first argument, and updates FX graph mode quantization to quantize whichever argument is actually a Tensor.
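
For illustration, a minimal eager-mode module (hypothetical, not taken from this PR) that produces the scalar-first pattern when traced with FX:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        x = 3 + x   # FX records operator.add with args (3, x): scalar first
        x = 3 * x   # likewise operator.mul with args (3, x)
        return x

Before this change, the FX convert step always quantized args[0] of add/mul (load_arg(quantized=[0])), which is wrong when args[0] is the scalar.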

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_quantized_add
python test/test_quantization.py TestQuantizeFxOps.test_quantized_mul
python test/test_quantization.py TestQuantizeFxOps.test_quantized_add_relu
python test/test_quantization.py TestQuantizeFxOps.test_quantized_mul_relu

Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D24494456

fbshipit-source-id: ef5e23ba60eb22a57771791f4934306b25c27c01
jerryzh168 authored and facebook-github-bot committed Oct 28, 2020
1 parent 810c68f commit 5a8198e
Showing 6 changed files with 72 additions and 16 deletions.
11 changes: 11 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -243,6 +243,15 @@ Tensor qadd_scalar(Tensor qa, Scalar b) {
  return _add_scalar_out<ReLUFused>(qc, qa, b);
}

template <bool ReLUFused = false>
Tensor qadd_scalar2(Scalar b, Tensor qa) {
  TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
              qa.qscheme() == kPerTensorSymmetric,
              "Only per tensor quantization is supported in Add.");
  auto qc = at::empty_like(qa, qa.suggest_memory_format());
  return _add_scalar_out<ReLUFused>(qc, qa, b);
}

template <bool ReLUFused = false>
Tensor qadd_scalar_out(Tensor qa, Scalar b, Tensor out) {
  check_inputs(qa, out);
@@ -269,10 +278,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::add"), TORCH_FN(qadd</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.out"), TORCH_FN(qadd_out</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar"), TORCH_FN(qadd_scalar</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar2"), TORCH_FN(qadd_scalar2</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add.Scalar_out"), TORCH_FN(qadd_scalar_out</*ReLUFused=*/false>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu"), TORCH_FN(qadd</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.out"), TORCH_FN(qadd_out</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar"), TORCH_FN(qadd_scalar</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar2"), TORCH_FN(qadd_scalar2</*ReLUFused=*/true>));
m.impl(TORCH_SELECTIVE_NAME("quantized::add_relu.Scalar_out"), TORCH_FN(qadd_scalar_out</*ReLUFused=*/true>));
// deprecated functions, kept for backward compatibility
m.impl(TORCH_SELECTIVE_NAME("quantized::add_out"), TORCH_FN(qadd_out</*ReLUFused=*/false>));
14 changes: 14 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qmul.cpp
@@ -136,6 +136,18 @@ class QMulScalar final {
}
};

template <bool ReLUFused = false>
class QMulScalar2 final {
 public:
  static Tensor run(Scalar b, Tensor qa) {
    TORCH_CHECK(qa.qscheme() == kPerTensorAffine ||
                qa.qscheme() == kPerTensorSymmetric,
                "Only per tensor quantization is supported in Mul.");
    auto qc = at::empty_like(qa, qa.suggest_memory_format());
    return _mul_scalar_out<ReLUFused>(qc, qa, b);
  }
};

template <bool ReLUFused = false>
class QMulScalarOut final {
public:
@@ -176,10 +188,12 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::mul"), TORCH_FN(QMul</*ReLUFused=*/false>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul.out"), TORCH_FN(QMulOut</*ReLUFused=*/false>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar"), TORCH_FN(QMulScalar</*ReLUFused=*/false>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar2"), TORCH_FN(QMulScalar2</*ReLUFused=*/false>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul.Scalar_out"), TORCH_FN(QMulScalarOut</*ReLUFused=*/false>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu"), TORCH_FN(QMul</*ReLUFused=*/true>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.out"), TORCH_FN(QMulOut</*ReLUFused=*/true>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar"), TORCH_FN(QMulScalar</*ReLUFused=*/true>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar2"), TORCH_FN(QMulScalar2</*ReLUFused=*/true>::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_relu.Scalar_out"), TORCH_FN(QMulScalarOut</*ReLUFused=*/true>::run));
// deprecated functions, kept for backward compatibility
m.impl(TORCH_SELECTIVE_NAME("quantized::mul_out"), TORCH_FN(QMulOut</*ReLUFused=*/false>::run));
4 changes: 4 additions & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -23,9 +23,11 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar(Tensor qa, Scalar b) -> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar2(Scalar b, Tensor qa) -> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar(Tensor qa, Scalar b) -> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar2(Scalar b, Tensor qa) -> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.out(Tensor qa, Tensor qb, Tensor(a!) out) -> Tensor(a!) out"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::add_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out) -> Tensor(a!) out"));
// deprecated functions, kept for backward compatibility
@@ -142,10 +144,12 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar(Tensor qa, Scalar b)-> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar2(Scalar b, Tensor qa)-> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)-> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar(Tensor qa, Scalar b)-> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar2(Scalar b, Tensor qa)-> Tensor qc"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_relu.Scalar_out(Tensor qa, Scalar b, Tensor(a!) out)-> Tensor(a!) out"));
// deprecated functions, kept for backward compatibility
m.def(TORCH_SELECTIVE_SCHEMA("quantized::mul_out(Tensor qa, Tensor qb, Tensor(a!) out)-> Tensor(a!) out"));
5 changes: 5 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -997,7 +997,10 @@ def __init__(self, is_inplace, is_scalar):
    def forward(self, x, y):
        x = self.conv1(x)
        y = 3 if self.is_scalar else self.conv2(y)
        # x = x + y
        x = self.op(x, y)
        # x = y + x
        x = self.op(y, x)
        return x

# TODO: decide whether we want to quantize or not
@@ -1040,6 +1043,8 @@ def forward(self, x, y):
        y = 3 if self.is_scalar else self.conv2(y)
        x = self.op(x, y)
        x = self.relu(x)
        x = self.op(y, x)
        x = self.relu(x)
        return x

data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
34 changes: 24 additions & 10 deletions torch/quantization/fx/quantization_patterns.py
@@ -46,7 +46,8 @@ def __init__(self, quantizer, node):
        # this is an indicator of whether all the inputs are Node or not
        # since some op might be quantized differently depending on whether
        # all inputs are tensors or not, e.g. add/mul
        self.all_nodes = True
        self.num_node_args = len(node.args)
        self.all_node_args = True

    @abstractmethod
    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
@@ -71,18 +72,24 @@ def __init__(self, quantizer, node):
            node = node.args[0]
        assert node.op == 'call_function' and node.target in [operator.add, torch.add]
        self.add_node = node
        self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])
        self.num_node_args = len([a for a in self.add_node.args[:2] if isinstance(a, Node)])

    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
        if not self.all_nodes:
        if self.num_node_args == 1:
            # add scalar
            if self.relu_node is not None:
                op = torch.ops.quantized.add_relu
            else:
                op = torch.ops.quantized.add

            if isinstance(self.add_node.args[0], Node):
                quantized_index = 0
            else:
                quantized_index = 1

            return quantizer.quantized_graph.create_node(
                'call_function', op,
                load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
                load_arg(quantized=[quantized_index])(self.add_node.args), self.add_node.kwargs)
        else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point = activation_post_process.calculate_qparams()
@@ -96,6 +103,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
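
For context, a sketch of the end-to-end behavior this enables (assumes the FX graph mode quantization entry points of this era, prepare_fx/convert_fx from torch.quantization.quantize_fx, and the {"": qconfig} qconfig_dict format; the module is hypothetical). For 3 + conv(x) only args[1] is a Node, so quantized_index is 1, only that argument gets quantized, and the emitted quantized::add call keeps the scalar in the first position, which dispatches to the new Scalar2 kernel:

import torch
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return 3 + self.conv(x)   # scalar-first add

m = prepare_fx(M().eval(), {"": default_qconfig})
m(torch.rand(1, 1, 2, 2))   # calibrate observers
m = convert_fx(m)
print(m.graph)   # quantized::add should appear with the scalar still as the first argument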

# TODO: merge with Add
@register_quant_pattern(operator.mul)
@register_quant_pattern(torch.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@@ -112,17 +120,23 @@ def __init__(self, quantizer, node):
            node = node.args[0]
        assert node.op == 'call_function' and node.target in [operator.mul, torch.mul]
        self.mul_node = node
        self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])
        self.num_node_args = len([a for a in self.mul_node.args[:2] if isinstance(a, Node)])

    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
        if not self.all_nodes:
        if self.num_node_args == 1:
            # mul scalar
            if self.relu_node is not None:
                op = torch.ops.quantized.mul_relu
            else:
                op = torch.ops.quantized.mul

            if isinstance(self.mul_node.args[0], Node):
                quantized_index = 0
            else:
                quantized_index = 1

            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
                'call_function', op, load_arg(quantized=[quantized_index])(self.mul_node.args), self.mul_node.kwargs)
        else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point = activation_post_process.calculate_qparams()
@@ -138,7 +152,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
        if not self.all_nodes:
        if not self.all_node_args:
            return NotImplemented
        activation_post_process = quantizer.activation_post_process_map[node.name]
        scale, zero_point = activation_post_process.calculate_qparams()
@@ -438,7 +452,7 @@ class DefaultNode(QuantizeHandler):
    ''' Common quantized op, first input and first output will be quantized
    '''
    def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
        if not self.all_nodes:
        if not self.all_node_args:
            return NotImplemented
        assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
            'call_function are handled in DefaultNode'
@@ -580,7 +594,7 @@ def convert(self, quantizer, node, load_arg, debug=False, convert_custom_config_dict=None):
# of quantizable objects (e.g. modules and functionals)
class DefaultQuantizeHandler(QuantizeHandler):
    def convert(self, quantizer, node):
        assert self.all_nodes
        assert self.all_node_args
        root_module = quantizer.modules['']
        return quantize_node(
            root_module,
20 changes: 14 additions & 6 deletions torch/quantization/fx/quantize.py
@@ -384,7 +384,7 @@ def load_arg(a):
                continue

            prefix = node.name + '_activation_post_process_'
            root_node, _, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
            root_node, matched_nodes, pattern, obj, qconfig = matches.get(node.name, (None, None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
@@ -458,15 +458,23 @@ def is_observed(input_arg):
                # propagate observed property from input
                if is_observed(node.args[0]):
                    observed_node_names_set.add(node.name)
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                if node.args[0].name in observed_node_names_set:
            elif (isinstance(obj, Add) or isinstance(obj, Mul)) and obj.num_node_args == 1:
                input_node = matched_nodes[-1]  # first node in the sequence

                def input_is_observed(arg):
                    return isinstance(arg, Node) and arg.name in observed_node_names_set
                # This is checking whether one of the arguments of add/mul
                # is an observed node
                # If both of the inputs are numbers,
                # we will not consider the output to be observed
                if input_is_observed(input_node.args[0]) or input_is_observed(input_node.args[1]):
                    observed_node_names_set.add(node.name)
            elif isinstance(obj, StandaloneModuleQuantizeHandler):
                assert node.op == 'call_module'
                output_is_observed = self.modules[node.target]._output_is_observed
                if output_is_observed:
                    observed_node_names_set.add(node.name)
            elif qconfig is not None and obj.all_nodes:
            elif qconfig is not None and obj.all_node_args:
                # observer for outputs
                new_observer = qconfig.activation()
                # respect device affinity when adding observers
@@ -741,8 +749,8 @@ def is_quantized(node):
                # the node is quantized in parent module
                quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
            else:
                # dequantize inputs for the node that are not quantized
                env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
                # copy quantized or non-quantized node
                env[node.name] = self.quantized_graph.node_copy(node, load_x)

# remove activation post process
act_post_process_removed_graph = Graph()
