118 changes: 61 additions & 57 deletions backends/qualcomm/builders/op_conv2d.py
@@ -18,6 +18,7 @@
OpDepthWiseConv2d,
OpExpandDims,
OpReshape,
OpTransposeConv2d,
QNN_OP_PACKAGE_NAME_QTI_AISW,
)
from .utils import get_parameter
@@ -42,6 +43,9 @@ def _add_conv_op_parameter(
padding_shape,
dilation,
dilation_shape,
output_padding=None,
output_padding_shape=None,
transpose_conv=False,
groups=None,
) -> PyQnnWrapper.PyQnnOpWrapper:
"""
@@ -68,14 +72,26 @@
),
True,
)
conv_op.AddTensorParam(
OP.param_dilation,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(dilation_shape),
dilation_shape,
np.array(dilation, dtype=np.uint32),
True,
)

if transpose_conv:
conv_op.AddTensorParam(
OP.param_output_padding,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(output_padding_shape),
output_padding_shape,
np.array(output_padding, dtype=np.uint32),
True,
)
else:
conv_op.AddTensorParam(
OP.param_dilation,
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
len(dilation_shape),
dilation_shape,
np.array(dilation, dtype=np.uint32),
True,
)

if groups is not None:
conv_op.AddScalarParam(
OP.param_group,
@@ -94,6 +110,11 @@ def _define_conv1d(
Conv1D is a special case for convolutional operation. QNN does not support Conv1D, therefore,
we need to cast from input -> Conv1d -> output to input -> unsqueeze -> Conv2d -> squeeze -> output.
"""
transpose_conv = cast(bool, node.args[6])
if transpose_conv:
print("ConvTranspose1d is not yet supported")
return

op_wrapper_list = [] # op_wrapper to return
unsqueeze_input_node = node.args[0]
input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
@@ -239,9 +260,9 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:

if get_parameter(node.args[1], self.edge_program).dim() == 3:
return self._define_conv1d(node, nodes_to_wrappers)

input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
@@ -254,8 +275,9 @@ def define_node(

filter_node = node.args[1]
filter_tensor = get_parameter(filter_node, self.edge_program)
# weight of pytorch OIHW, yet QNN is HWIO
filter_axis_order = (2, 3, 1, 0)
# weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
is_transpose_conv = cast(bool, node.args[6])
filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0)
filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
filter_tensor_wrapper = self.define_tensor(
filter_node,
@@ -291,6 +313,7 @@ def define_node(
stride = cast(List[int], node.args[3])
padding = cast(List[int], node.args[4])
dilation = cast(List[int], node.args[5])
output_padding = cast(List[int], node.args[7])

groups = cast(int, node.args[8])
# Qnn filter tensor is (H, W, Cin, Cout)
@@ -308,57 +331,38 @@ def define_node(
if len(padding) == 1:
padding = padding + padding

# args[6] = transposed
if cast(bool, node.args[6]):
print("Currently, No support for transposed convolution")
return

# args[7] = output padding
if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
print("QNN does not support output padding")
return

stride_shape = [len(stride)]
padding_shape = [2, 2]
dilation_shape = [len(dilation)]
output_padding_shape = [len(output_padding)]

if is_depthwise_conv:
conv_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpDepthWiseConv2d.op_name,
)
conv_op = self._add_conv_op_parameter(
OpDepthWiseConv2d,
conv_op,
conv_input_tensors,
conv_output_tensors,
stride,
stride_shape,
padding,
padding_shape,
dilation,
dilation_shape,
)

op_class = OpDepthWiseConv2d
elif is_transpose_conv:
op_class = OpTransposeConv2d
else:
conv_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpConv2d.op_name,
)
conv_op = self._add_conv_op_parameter(
OpConv2d,
conv_op,
conv_input_tensors,
conv_output_tensors,
stride,
stride_shape,
padding,
padding_shape,
dilation,
dilation_shape,
groups,
)
op_class = OpConv2d

conv_op = PyQnnWrapper.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
op_class.op_name,
)
conv_op = self._add_conv_op_parameter(
op_class,
conv_op,
conv_input_tensors,
conv_output_tensors,
stride,
stride_shape,
padding,
padding_shape,
dilation,
dilation_shape,
output_padding,
output_padding_shape,
is_transpose_conv,
None if is_depthwise_conv else groups,
)

return conv_op
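
A quick aside on the filter-layout comment above (not part of the diff): PyTorch stores Conv2d weights as OIHW and ConvTranspose2d weights as IOHW, while QNN expects HWIO, which is why the builder now selects between the two axis orders. A minimal sketch with illustrative channel counts:

```python
import torch

# Conv2d weight is (out, in, kH, kW) = OIHW; ConvTranspose2d weight is
# (in, out, kH, kW) = IOHW. Both permute to QNN's HWIO = (kH, kW, in, out).
conv_w = torch.nn.Conv2d(4, 8, kernel_size=3).weight             # (8, 4, 3, 3)
deconv_w = torch.nn.ConvTranspose2d(4, 8, kernel_size=3).weight  # (4, 8, 3, 3)

hwio_from_conv = conv_w.permute(2, 3, 1, 0)      # axis order for Conv2d
hwio_from_deconv = deconv_w.permute(2, 3, 0, 1)  # axis order for ConvTranspose2d

assert hwio_from_conv.shape == hwio_from_deconv.shape == (3, 3, 4, 8)
```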
9 changes: 9 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -356,3 +356,12 @@ class OpTile:
class OpTranspose:
op_name: str = "Transpose"
param_perm: str = "perm"


@dataclass(init=False, frozen=True)
class OpTransposeConv2d:
op_name: str = "TransposeConv2d"
param_stride: str = "stride"
param_pad_amount: str = "pad_amount"
param_group: str = "group"
param_output_padding: str = "output_padding"
8 changes: 7 additions & 1 deletion backends/qualcomm/quantizer/utils.py
@@ -941,7 +941,13 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
node.meta["source_fn_stack"] = [(node, torch.bmm)]


@register_annotator([torch.ops.aten.conv2d.default, torch.ops.aten.conv1d.default])
@register_annotator(
[
torch.ops.aten.conv2d.default,
torch.ops.aten.conv1d.default,
torch.ops.aten.conv_transpose2d.input,
]
)
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
return
40 changes: 40 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -361,6 +361,46 @@ def forward(self, x):
return self.conv(x)


class ConvTranspose2dSingle(torch.nn.Module):
def __init__(self, bias=True):
super().__init__()
self.conv_transpose = torch.nn.ConvTranspose2d(
in_channels=1,
out_channels=3,
kernel_size=3,
stride=2,
padding=1,
bias=bias,
)

def forward(self, x):
return self.conv_transpose(x)


class Conv2dDownUpSample(torch.nn.Module):
def __init__(self, bias=True):
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=16,
out_channels=16,
kernel_size=3,
stride=2,
padding=1,
bias=bias,
)
self.conv_transpose = torch.nn.ConvTranspose2d(
in_channels=16,
out_channels=16,
kernel_size=3,
stride=2,
padding=1,
bias=bias,
)

def forward(self, x):
return self.conv_transpose(self.conv(x))


class Conv2dSumReduceDim(torch.nn.Module):
def __init__(self):
super().__init__()
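
For reviewers sanity-checking the new fixtures (again, not part of the diff): with kernel_size=3, stride=2, padding=1 and the default output_padding=0, the transposed conv follows H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding, so the (1, 1, 3, 3) sample input used in the tests should produce a (1, 3, 5, 5) output:

```python
import torch

# H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding
#       = (3 - 1) * 2 - 2 * 1 + 3 + 0 = 5
m = torch.nn.ConvTranspose2d(in_channels=1, out_channels=3, kernel_size=3,
                             stride=2, padding=1)
out = m(torch.randn(1, 1, 3, 3))
assert out.shape == (1, 3, 5, 5)
```

By the same formula, the Conv2dDownUpSample pair maps 224 -> 112 -> 223 with output_padding left at 0, which is exactly the output-size ambiguity that output_padding exists to resolve and why the builder now forwards it to QNN.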
32 changes: 32 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -130,6 +130,16 @@ def test_qnn_backend_conv2d(self):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv_transpose2d(self):
modules = [
ConvTranspose2dSingle(), # noqa: F405
ConvTranspose2dSingle(bias=False), # noqa: F405
]
sample_input = (torch.randn([1, 1, 3, 3]),)
for i, module in enumerate(modules):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_element_wise_add(self):
test_comb = [
{
@@ -521,6 +531,11 @@ def test_qnn_backend_conv2d_cat(self):
sample_input = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d_down_up_sample(self):
module = Conv2dDownUpSample() # noqa: F405
sample_input = (torch.randn(1, 16, 224, 224),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d_max_pool2d(self):
module = Conv2dMaxPool2d() # noqa: F405
sample_input = (torch.rand(1, 2, 14, 14),)
@@ -713,6 +728,17 @@ def test_qnn_backend_conv2d(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv_transpose2d(self):
modules = [
ConvTranspose2dSingle(), # noqa: F405
ConvTranspose2dSingle(bias=False), # noqa: F405
] # noqa: F405
sample_input = (torch.randn([1, 1, 3, 3]),)
for i, module in enumerate(modules):
with self.subTest(i=i):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_element_wise_add(self):
test_comb = [
{
@@ -1157,6 +1183,12 @@ def test_qnn_backend_conv2d_cat(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d_down_up_sample(self):
module = Conv2dDownUpSample() # noqa: F405
sample_input = (torch.randn(1, 16, 224, 224),)
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d_max_pool2d(self):
module = Conv2dMaxPool2d() # noqa: F405
sample_input = (torch.rand(1, 2, 14, 14),)