Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backends/xnnpack/partition/config/gemm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,18 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
is_transpose = node.args[6]
groups = cast(int, node.args[8])

# XNNPack does not support non-zero output padding in transposed
# convolutions.
if is_transpose and any(
out_pad != 0 for out_pad in cast(List[int], node.args[7])
):
why(
node,
"XNNPACK does not support transposed convolutions with"
"non-zero output padding",
)
return False

if (
is_transpose
and weight_quant_params is not None
Expand Down
39 changes: 39 additions & 0 deletions backends/xnnpack/test/ops/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,42 @@ def get_inputs(self):
quant_config=None,
conv_count=1,
)

def test_padded_output_tconv(self):
class TConv2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.ConvTranspose2d(
in_channels=2,
out_channels=1,
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
output_padding=(0, 1),
dilation=(1, 1),
groups=1,
bias=True,
).to(torch.float)

def forward(self, x):
return self.conv(x)

m = TConv2d()
inputs = (torch.randn(1, 2, 8, 8),)
tester = Tester(m.eval(), inputs)

conv_count: int = 1
op = "torch.ops.aten.conv_transpose2d"

(tester.export().check_count({op: conv_count}).to_edge_transform_and_lower())

# tconv should not be offloaded to XNNPack, since output padding is not
(
tester.check(
["executorch_exir_dialects_edge__ops_aten_convolution_default"]
)
.check_not(["torch.ops.higher_order.executorch_call_delegate"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs(qtol=1)
)
Loading