[ONNX] Add support for QLinearConcat contrib op (apache#8907)
* add qlinearconcat op

* fix tests

* Fix

* lint

* lint

* review

* boop ci

* fix regression

* noop

* jostle ci
anwang2009 authored and ylc committed Jan 13, 2022
1 parent b643981 commit 5d8b03f
Showing 2 changed files with 90 additions and 45 deletions.
101 changes: 56 additions & 45 deletions python/tvm/relay/frontend/onnx.py
@@ -196,6 +196,17 @@ def _dim_check(attrs):
     return _dim_check, "Only 1d, 2d and 3d kernel supported."


+def get_scalar(x, params, dtype="float32"):
+    """Helper to get a scalar value for Quantized operators."""
+    if isinstance(x, _expr.Var) and x.name_hint in params:
+        return _op.const(params[x.name_hint].numpy(), dtype)
+    rank = len(infer_shape(x))
+    assert rank <= 1, "scale and zero_point input must be scalars"
+    if rank == 1:
+        x = _op.squeeze(x, [0])
+    return _op.cast(x, dtype)
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""

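Aside, not part of the diff: the hunk above hoists the helper that QLinearConv, QLinearAdd, and QLinearMul previously each defined inline, adding an explicit params argument. A minimal illustration of its two paths, with hypothetical names, assuming the usual tvm/relay imports:

# Hypothetical illustration of get_scalar (not from this patch).
# Path 1: a Var bound in params folds to a compile-time constant.
# Path 2: anything else is squeezed (if rank 1) and cast at runtime.
import numpy as np
import tvm
from tvm import relay

x = relay.var("x_scale", shape=(1,), dtype="float32")
folded = get_scalar(x, {"x_scale": tvm.nd.array(np.array([0.5], dtype="float32"))})  # path 1
dynamic = get_scalar(x, {})                                                          # path 2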
@@ -3196,23 +3207,14 @@ class QLinearConv(OnnxOpConverter):

     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x, dtype="float32"):
-            if isinstance(x, _expr.Var) and x.name_hint in params:
-                return _op.const(params[x.name_hint].numpy(), dtype)
-            rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearConv scale and zero_point input must be scalars"
-            if rank == 1:
-                x = _op.squeeze(x, [0])
-            return _op.cast(x, dtype)
-
         data = inputs[0]
-        x_scale = get_scalar(inputs[1])
-        x_zero_point = get_scalar(inputs[2], "int32")
+        x_scale = get_scalar(inputs[1], params)
+        x_zero_point = get_scalar(inputs[2], params, "int32")
         weight = inputs[3]
-        w_scale = get_scalar(inputs[4])
-        w_zero_point = get_scalar(inputs[5], "int32")
-        y_scale = fold_constant(get_scalar(inputs[6]))
-        y_zero_point = get_scalar(inputs[7], "int32")
+        w_scale = get_scalar(inputs[4], params)
+        w_zero_point = get_scalar(inputs[5], params, "int32")
+        y_scale = fold_constant(get_scalar(inputs[6], params))
+        y_zero_point = get_scalar(inputs[7], params, "int32")

         input_shape = infer_shape(data)

@@ -3300,23 +3302,14 @@ class QLinearAdd(OnnxOpConverter):

     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x, dtype="float32"):
-            if isinstance(x, _expr.Var) and x.name_hint in params:
-                return _op.const(params[x.name_hint].numpy(), dtype)
-            rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearConv scale and zero_point input must be scalars"
-            if rank == 1:
-                x = _op.squeeze(x, [0])
-            return _op.cast(x, dtype)
-
         a = inputs[0]
-        a_scale = get_scalar(inputs[1])
-        a_zero_point = get_scalar(inputs[2], "int32")
+        a_scale = get_scalar(inputs[1], params)
+        a_zero_point = get_scalar(inputs[2], params, "int32")
         b = inputs[3]
-        b_scale = get_scalar(inputs[4])
-        b_zero_point = get_scalar(inputs[5], "int32")
-        c_scale = get_scalar(inputs[6])
-        c_zero_point = get_scalar(inputs[7], "int32")
+        b_scale = get_scalar(inputs[4], params)
+        b_zero_point = get_scalar(inputs[5], params, "int32")
+        c_scale = get_scalar(inputs[6], params)
+        c_zero_point = get_scalar(inputs[7], params, "int32")

         dtype = infer_type(a).checked_type.dtype

@@ -3338,23 +3331,14 @@ class QLinearMul(OnnxOpConverter):

     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        def get_scalar(x, dtype="float32"):
-            if isinstance(x, _expr.Var) and x.name_hint in params:
-                return _op.const(params[x.name_hint].numpy(), dtype)
-            rank = len(infer_shape(x))
-            assert rank <= 1, "QLinearMul scale and zero_point input must be scalars"
-            if rank == 1:
-                x = _op.squeeze(x, [0])
-            return _op.cast(x, dtype)
-
         a = inputs[0]
-        a_scale = get_scalar(inputs[1])
-        a_zero_point = get_scalar(inputs[2], "int32")
+        a_scale = get_scalar(inputs[1], params)
+        a_zero_point = get_scalar(inputs[2], params, "int32")
         b = inputs[3]
-        b_scale = get_scalar(inputs[4])
-        b_zero_point = get_scalar(inputs[5], "int32")
-        y_scale = fold_constant(get_scalar(inputs[6]))
-        y_zero_point = get_scalar(inputs[7], "int32")
+        b_scale = get_scalar(inputs[4], params)
+        b_zero_point = get_scalar(inputs[5], params, "int32")
+        y_scale = fold_constant(get_scalar(inputs[6], params))
+        y_zero_point = get_scalar(inputs[7], params, "int32")

         dtype = infer_type(a).checked_type.dtype

@@ -3367,6 +3351,32 @@ def get_scalar(x, dtype="float32"):
         return _qnn.op.quantize(out, y_scale, y_zero_point, out_dtype=dtype)


+class QLinearConcat(OnnxOpConverter):
+    """Operator converter for QLinearConcat from Microsoft onnxruntime contrib opset."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # which axis to concat on
+        axis = attr["axis"]
+
+        y_scale = fold_constant(get_scalar(inputs[0], params))
+        y_zero_point = get_scalar(inputs[1], params, "int32")
+
+        # input tensors, scales, zero_points
+        assert (
+            len(inputs) % 3 == 2
+        ), "Additional input count must be a multiple of 3 -- tensor/scale/zero_point tuples"
+        tensors = []
+        scales = []
+        zero_points = []
+        for i in range(2, len(inputs), 3):
+            tensors.append(inputs[i])
+            scales.append(get_scalar(inputs[i + 1], params))
+            zero_points.append(get_scalar(inputs[i + 2], params, "int32"))
+
+        return _qnn.op.concatenate(tensors, scales, zero_points, y_scale, y_zero_point, axis)
+
+
 class ConvInteger(OnnxOpConverter):
     """Operator converter for ConvInteger."""

Expand Down Expand Up @@ -3748,6 +3758,7 @@ def _get_convert_map(opset):
"DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
"ReverseSequence": ReverseSequence.get_converter(opset),
"QLinearConv": QLinearConv.get_converter(opset),
"QLinearConcat": QLinearConcat.get_converter(opset),
"QLinearAdd": QLinearAdd.get_converter(opset),
"QLinearMul": QLinearMul.get_converter(opset),
"ConvInteger": ConvInteger.get_converter(opset),
Expand Down
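A note on the converter above, not part of the diff: QLinearConcat packs its inputs as y_scale, y_zero_point, then one (tensor, scale, zero_point) triple per concatenated input, which is what the len(inputs) % 3 == 2 assertion enforces. A minimal sketch of building such a node with onnx.helper, with illustrative tensor names:

# Sketch: a QLinearConcat node with two quantized inputs (illustrative names).
# Input order: y_scale, y_zero_point, then (tensor, scale, zero_point) per
# input, so len(inputs) == 2 + 3 * num_tensors.
from onnx import helper

node = helper.make_node(
    "QLinearConcat",
    inputs=[
        "y_scale", "y_zero_point",       # output quantization params
        "a", "a_scale", "a_zero_point",  # first (tensor, scale, zero_point) triple
        "b", "b_scale", "b_zero_point",  # second triple
    ],
    outputs=["y"],
    axis=0,
    domain="com.microsoft",              # onnxruntime contrib opset
)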
34 changes: 34 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
@@ -5358,6 +5358,39 @@ def repeat(N, D):
     )


+@tvm.testing.parametrize_targets
+def test_qlinearconcat(target, dev):
+    def verify_qlinearconcat(shapes, out_shape, axis=None):
+        input_names = []
+        input_values = []
+        input_nodes = []
+        for i in range(len(shapes)):
+            tensor_name = chr(ord("a") + i)
+            shape = shapes[i]
+            node = helper.make_tensor_value_info(tensor_name, TensorProto.FLOAT, list(shape))
+
+            input_names.append(tensor_name)
+            input_values.append(np.random.random(shape).astype("float32"))
+            input_nodes.append(node)
+
+        node = helper.make_node("Concat", input_names, ["C"])
+        if axis is not None:
+            axis_attr = helper.make_attribute("axis", axis)
+            node.attribute.append(axis_attr)
+        graph = helper.make_graph(
+            [node],
+            "qlinearconcat_test",
+            inputs=input_nodes,
+            outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(out_shape))],
+        )
+        model = helper.make_model(graph, producer_name="qlinearconcat_test")
+        quantize_and_verify_with_ort(model, input_names, shapes, target, dev)
+
+    verify_qlinearconcat([[2, 1], [2, 1]], [4, 1], 0)
+    verify_qlinearconcat([[2, 1], [2, 1]], [2, 2], 1)
+    verify_qlinearconcat([[1, 2], [2, 2], [3, 2]], [6, 2], 0)
+
+
 @tvm.testing.parametrize_targets
 def test_qlinearadd(target, dev):
     def verify_qlinearadd(a_shape, b_shape, c_shape):
@@ -5716,6 +5749,7 @@ def repeat(N, D):
     test_index_put()
     test_reverse_sequence()
     test_eyelike()
+    test_qlinearconcat()
     test_qlinearconv()
     test_random_uniform()
     test_convinteger()
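For intuition about what _qnn.op.concatenate does with the arguments the converter passes it: each input is requantized from its own (scale, zero_point) into the output's quantization params before the concat itself. A conceptual per-element sketch of that math, not code from this patch:

import numpy as np

def requantize(x_q, x_scale, x_zp, y_scale, y_zp, dtype=np.uint8):
    # Dequantize with the input's params, then requantize with the output's --
    # conceptually what qnn.concatenate does to each input tensor.
    real = x_scale * (x_q.astype(np.int32) - x_zp)
    y_q = np.round(real / y_scale).astype(np.int32) + y_zp
    info = np.iinfo(dtype)
    return np.clip(y_q, info.min, info.max).astype(dtype)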
