Skip to content

Commit

Permalink
[Relay][Frontend][Onnx] Enable group_conv1d import through conv2d con…
Browse files Browse the repository at this point in the history
…version. (apache#8321)

* Enable group conv1d import through conv2d hack.

* remove silly commented out lines.
  • Loading branch information
Josh Fromm authored and ylc committed Sep 29, 2021
1 parent 860cacd commit 0b76568
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
28 changes: 26 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ class Conv(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
# Use shape of input to determine convolution type.
data = inputs[0]
kernel = inputs[1]
input_shape = infer_shape(data)
ndim = len(input_shape)

Expand All @@ -473,13 +474,32 @@ def _impl_v1(cls, inputs, attr, params):
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
attr["pads"] = [0 for i in range(ndim - 2)]
elif attr["auto_pad"] == "NOTSET":
pass
else:
msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
attr.pop("auto_pad")

# Check if the requested convolution is a group conv1d, if so convert it to conv2d.
# TODO(jwfromm) Remove once proper group_conv1d is supported.
group_conv1d = False
if dimension_picker("conv")(attr) == "conv1d" and attr.get("group") != 1:
group_conv1d = True
# Expand input from NCW to NCHW
data = _op.expand_dims(data, axis=2)
# Expand kernel from OIW to OIHW
kernel = _op.expand_dims(kernel, axis=2)
# Add new value to kernel_shape, strices, dilation, pads, if needed
attr["kernel_shape"] = [1] + list(attr["kernel_shape"])
if "strides" in attr:
attr["strides"] = [1] + list(attr["strides"])
if "dilations" in attr:
attr["dilations"] = [1] + list(attr["dilations"])
if "pads" in attr:
attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]]

out = AttrCvt(
op_name=dimension_picker("conv"),
transforms={
Expand All @@ -489,7 +509,11 @@ def _impl_v1(cls, inputs, attr, params):
"group": ("groups", 1),
},
custom_check=dimension_constraint(),
)([data, inputs[1]], attr, params)
)([data, kernel], attr, params)

# If this was a group_conv1d, squish output back to NCW.
if group_conv1d:
out = _op.squeeze(out, axis=[2])

use_bias = len(inputs) == 3
if use_bias:
Expand Down
20 changes: 18 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,6 +2410,7 @@ def verify_conv(
kernel_shape,
strides,
dilations,
group=1,
auto_pad="NOTSET",
unset_pad=False,
):
Expand All @@ -2422,7 +2423,7 @@ def verify_conv(
# Default values for other attributes:
strides=strides,
dilations=dilations,
# groups=1
group=group,
)
elif padding is None:
## autopadding with unset default attributes
Expand All @@ -2438,6 +2439,7 @@ def verify_conv(
outputs=["y"],
# Default values for other attributes:
auto_pad=auto_pad,
group=group,
**kwargs,
)
else:
Expand All @@ -2449,7 +2451,7 @@ def verify_conv(
# Default values for other attributes:
strides=strides,
dilations=dilations,
# groups=1
group=group,
pads=padding,
)

Expand Down Expand Up @@ -2559,6 +2561,20 @@ def repeat(N, D):
repeat(2, D),
)

# TODO(jwfromm): Merge with other tests once group_conv3d is supported.
for D in [1, 2]:
# Group Convolution
verify_conv(
(1, 8) + repeat(5, D),
(8, 1) + repeat(3, D),
(1, 8) + repeat(5, D),
2 * repeat(1, D),
repeat(3, D),
repeat(1, D),
repeat(1, D),
group=8,
)


def verify_convtranspose_with_padding(
x_shape,
Expand Down

0 comments on commit 0b76568

Please sign in to comment.