Skip to content

Commit

Permalink
[Frontend][PaddlePaddle] PaddlePaddle model with NCHW data format tha…
Browse files Browse the repository at this point in the history
…t supports quantization (apache#16651)

* support conv2d when data_format is NHWC

* modify the annotation

* Do not convert input data when processing quantization conv_2d nodes

* Fix code formatting issues

* fixed error code format

* update dequantize and quantize

* fixed bug when model is fp32 model

* update dequantize and quantize

* update for paddle quantize model when format is NCHW
  • Loading branch information
Zheng-Bicheng authored and thaisacs committed Apr 3, 2024
1 parent 9240277 commit a80aa9f
Showing 1 changed file with 74 additions and 9 deletions.
83 changes: 74 additions & 9 deletions python/tvm/relay/frontend/paddlepaddle.py
Expand Up @@ -31,6 +31,7 @@
from .. import function as _function
from .. import ty as _ty
from .. import op as _op
from .. import qnn as _qnn
from .common import (
autopad,
fold_constant,
Expand Down Expand Up @@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
strides = op.attr("strides")

kernel = g.get_node(op.input("Filter")[0])
kernel_layout = "OIHW"
input_x = g.get_node(op.input("Input")[0])
data_layout = op.attr("data_format")
kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
out_channels, _, k_h, k_w = infer_shape(kernel)
if padding_algorithm == "VALID":
paddings = [0, 0]
Expand All @@ -336,9 +337,15 @@ def convert_conv2d(g, op, block):
msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."'
raise tvm.error.OpAttributeInvalid(msg)

if data_layout == "NHWC":
kernel_layout = "HWIO"
# PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC".
is_quantized = op.has_attr("quantization_type")
    # PaddlePaddle weight layout is "OIHW"; TVM needs "HWIO" when the op data_format is "NHWC".
    # There are two situations when converting the data format of weights:
    # 1. Conv2d is not a quantized op: its weight information is the weights themselves.
    #    We directly convert the weight layout when processing conv2d.
    # 2. Conv2d is a quantized op: its weight information is the output of
    #    the quantize_linear operator. Therefore, the weight layout is
    #    transformed when processing the quantize_linear operator.
if (not is_quantized) and (data_layout == "NHWC"):
kernel_data = g.get_params(op.input("Filter")[0])
kernel_data = kernel_data.asnumpy()
kernel_data = kernel_data.transpose((2, 3, 1, 0))
Expand Down Expand Up @@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))

# handle with special case
# while kernel size less than input size
# while kernel size more than input size
# shrink kernel size to input size
if (
not isinstance(in_h, _op.Expr)
Expand Down Expand Up @@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_dequantize_linear(g, op, block):
    """Operator converter for dequantize_linear."""

    x = g.get_node(op.input("X")[0])

    # PaddlePaddle stores scales pre-multiplied by 127 (the int8 range),
    # i.e. paddle_scale = tvm_scale * 127, so divide to recover the scale
    # expected by TVM's qnn ops.
    scale = g.get_params(op.input("Scale")[0]).asnumpy() / 127.0
    zero_point = g.get_params(op.input("ZeroPoint")[0]).asnumpy()

    axis = op.attr("quant_axis")
    # -1 marks per-tensor quantization; qnn.dequantize expects a
    # non-negative axis.  Data of rank < 2 only has axis 0 anyway.
    if axis == -1:
        axis = 0
    if len(infer_shape(x)) < 2:
        axis = 0

    out = _qnn.op.dequantize(
        data=x,
        input_scale=_op.const(scale, "float32"),
        input_zero_point=_op.const(zero_point, "int32"),
        axis=axis,
    )
    g.add_node(op.output("Y")[0], out)


def convert_quantize_linear(g, op, block):
    """Operator converter for quantize_linear."""

    data_node_name = op.input("X")[0]
    data_node = g.get_node(data_node_name)

    # PaddlePaddle stores scales pre-multiplied by 127 (the int8 range):
    # paddle_scale = tvm_scale * 127.  Divide to recover the scale TVM expects.
    paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
    tvm_quantize_scale = paddle_quantize_scale / 127.0

    tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
    tvm_quantize_axis = op.attr("quant_axis")

    # -1 marks per-tensor quantization; qnn.quantize expects a non-negative axis.
    if tvm_quantize_axis == -1:
        tvm_quantize_axis = 0

    out = _qnn.op.quantize(
        data=data_node,
        output_scale=_op.const(tvm_quantize_scale, "float32"),
        output_zero_point=_op.const(tvm_quantize_zp, "int32"),
        axis=tvm_quantize_axis,
    )
    g.add_node(op.output("Y")[0], out)


def convert_rnn(g, op, block):
"""Operator converter for rnn."""

Expand Down Expand Up @@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
def convert_softmax(g, op, block):
"""Operator converter for softmax."""

x = g.get_node(op.input("X")[0])
axis = op.attr("axis")
input_shape = block.var(op.input("X")[0]).shape
if axis < 0:
axis = len(input_shape) + axis
x = g.get_node(op.input("X")[0])
m = _op.max(x, axis, keepdims=True)
e = _op.exp(x - m)
out = e / _op.sum(e, axis, keepdims=True)
Expand Down Expand Up @@ -2905,6 +2965,9 @@ def convert_where_index(g, op, block):
"unstack": convert_unstack,
"where": convert_where,
"where_index": convert_where_index,
# Quantized
"dequantize_linear": convert_dequantize_linear,
"quantize_linear": convert_quantize_linear,
}


Expand Down Expand Up @@ -2938,7 +3001,7 @@ def get_params(self, name=None):

if name is None:
return self.params
assert name in self.params
assert name in self.params, f"The name({name}) is not in params"
return self.params[name]

def extract_parameters(self, program, scope=None):
Expand All @@ -2947,9 +3010,12 @@ def extract_parameters(self, program, scope=None):
self.params = {}
variables = program.global_block().vars
for name in variables:
var = program.global_block().var(name)
if name.endswith("feed") or name.endswith("fetch"):
continue
# This judgment will cause the PaddleInference model
# exported by PaddleSlim to skip some operators
# that need to be read in NHWC format.
var = program.global_block().var(name)
if not var.persistable:
continue
if isinstance(scope, dict):
Expand Down Expand Up @@ -3018,7 +3084,6 @@ def from_program(self, program, shape_dict, scope):
for op in block.ops:
if op.type == "fetch":
output_names.append(op.input("X")[0])

outputs = [self.nodes[name] for name in output_names]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)

Expand Down

0 comments on commit a80aa9f

Please sign in to comment.